HoneyTian committed
Commit ed91efa · 1 Parent(s): c255825

add dfnet2

examples/dfnet2/run.sh ADDED
@@ -0,0 +1,156 @@
1
+ #!/usr/bin/env bash
2
+
3
+ : <<'END'
4
+
5
+ sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir --final_model_name dfnet-nx-speech \
6
+ --noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
7
+ --speech_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/speech"
8
+
9
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dfnet-nx-dns3 \
10
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
11
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
12
+
13
+
14
+ END
15
+
16
+
17
+ # params
18
+ system_version="windows";
19
+ verbose=true;
20
+ stage=0 # start from 0 if you need to start from data preparation
21
+ stop_stage=9
22
+
23
+ work_dir="$(pwd)"
24
+ file_folder_name=file_folder_name
25
+ final_model_name=final_model_name
26
+ config_file="yaml/config.yaml"
27
+ limit=10
28
+
29
+ noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
30
+ speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
31
+
32
+ max_count=10000000
33
+
34
+ nohup_name=nohup.out
35
+
36
+ # model params
37
+ batch_size=64
38
+ max_epochs=200
39
+ save_top_k=10
40
+ patience=5
41
+
42
+
43
+ # parse options
44
+ while true; do
45
+ [ -z "${1:-}" ] && break; # break if there are no arguments
46
+ case "$1" in
47
+ --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
48
+ eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
49
+ old_value="$(eval echo \$$name)";
50
+ if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
51
+ was_bool=true;
52
+ else
53
+ was_bool=false;
54
+ fi
55
+
56
+ # Set the variable to the right value-- the escaped quotes make it work if
57
+ # the option had spaces, like --cmd "queue.pl -sync y"
58
+ eval "${name}=\"$2\"";
59
+
60
+ # Check that Boolean-valued arguments are really Boolean.
61
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
62
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
63
+ exit 1;
64
+ fi
65
+ shift 2;
66
+ ;;
67
+
68
+ *) break;
69
+ esac
70
+ done
71
+
72
+ file_dir="${work_dir}/${file_folder_name}"
73
+ final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
74
+ evaluation_audio_dir="${file_dir}/evaluation_audio"
75
+
76
+ train_dataset="${file_dir}/train.jsonl"
77
+ valid_dataset="${file_dir}/valid.jsonl"
78
+
79
+ $verbose && echo "system_version: ${system_version}"
80
+ $verbose && echo "file_folder_name: ${file_folder_name}"
81
+
82
+ if [ $system_version == "windows" ]; then
83
+ alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
84
+ elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
85
+ #source /data/local/bin/nx_denoise/bin/activate
86
+ alias python3='/data/local/bin/nx_denoise/bin/python3'
87
+ fi
88
+
89
+
90
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
91
+ $verbose && echo "stage 1: prepare data"
92
+ cd "${work_dir}" || exit 1
93
+ python3 step_1_prepare_data.py \
94
+ --file_dir "${file_dir}" \
95
+ --noise_dir "${noise_dir}" \
96
+ --speech_dir "${speech_dir}" \
97
+ --train_dataset "${train_dataset}" \
98
+ --valid_dataset "${valid_dataset}" \
99
+ --max_count "${max_count}" \
100
+
101
+ fi
102
+
103
+
104
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
105
+ $verbose && echo "stage 2: train model"
106
+ cd "${work_dir}" || exit 1
107
+ python3 step_2_train_model.py \
108
+ --train_dataset "${train_dataset}" \
109
+ --valid_dataset "${valid_dataset}" \
110
+ --serialization_dir "${file_dir}" \
111
+ --config_file "${config_file}" \
112
+
113
+ fi
114
+
115
+
116
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
117
+ $verbose && echo "stage 3: test model"
118
+ cd "${work_dir}" || exit 1
119
+ python3 step_3_evaluation.py \
120
+ --valid_dataset "${valid_dataset}" \
121
+ --model_dir "${file_dir}/best" \
122
+ --evaluation_audio_dir "${evaluation_audio_dir}" \
123
+ --limit "${limit}" \
124
+
125
+ fi
126
+
127
+
128
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
129
+ $verbose && echo "stage 4: collect files"
130
+ cd "${work_dir}" || exit 1
131
+
132
+ mkdir -p ${final_model_dir}
133
+
134
+ cp "${file_dir}/best"/* "${final_model_dir}"
135
+ cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
136
+
137
+ cd "${final_model_dir}/.." || exit 1;
138
+
139
+ if [ -e "${final_model_name}.zip" ]; then
140
+ rm -rf "${final_model_name}_backup.zip"
141
+ mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
142
+ fi
143
+
144
+ zip -r "${final_model_name}.zip" "${final_model_name}"
145
+ rm -rf "${final_model_name}"
146
+
147
+ fi
148
+
149
+
150
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
151
+ $verbose && echo "stage 5: clear file_dir"
152
+ cd "${work_dir}" || exit 1
153
+
154
+ rm -rf "${file_dir}";
155
+
156
+ fi
examples/dfnet2/step_1_prepare_data.py ADDED
@@ -0,0 +1,164 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import json
5
+ import os
6
+ from pathlib import Path
7
+ import random
8
+ import sys
9
+
10
+ pwd = os.path.abspath(os.path.dirname(__file__))
11
+ sys.path.append(os.path.join(pwd, "../../"))
12
+
13
+ import librosa
14
+ import numpy as np
15
+ from tqdm import tqdm
16
+
17
+
18
+ def get_args():
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument("--file_dir", default="./", type=str)
21
+
22
+ parser.add_argument(
23
+ "--noise_dir",
24
+ default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
25
+ type=str
26
+ )
27
+ parser.add_argument(
28
+ "--speech_dir",
29
+ default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
30
+ type=str
31
+ )
32
+
33
+ parser.add_argument("--train_dataset", default="train.jsonl", type=str)
34
+ parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
35
+
36
+ parser.add_argument("--duration", default=4.0, type=float)
37
+ parser.add_argument("--min_snr_db", default=-10, type=float)
38
+ parser.add_argument("--max_snr_db", default=20, type=float)
39
+
40
+ parser.add_argument("--target_sample_rate", default=8000, type=int)
41
+
42
+ parser.add_argument("--max_count", default=10000, type=int)
43
+
44
+ args = parser.parse_args()
45
+ return args
46
+
47
+
48
+ def filename_generator(data_dir: str):
49
+ data_dir = Path(data_dir)
50
+ for filename in data_dir.glob("**/*.wav"):
51
+ yield filename.as_posix()
52
+
53
+
54
+ def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000, max_epoch: int = 20000):
55
+ data_dir = Path(data_dir)
56
+ for epoch_idx in range(max_epoch):
57
+ for filename in data_dir.glob("**/*.wav"):
58
+ signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
59
+ raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
60
+
61
+ if raw_duration < duration:
62
+ # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
63
+ continue
64
+ if signal.ndim != 1:
65
+ raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
66
+
67
+ signal_length = len(signal)
68
+ win_size = int(duration * sample_rate)
69
+ for begin in range(0, signal_length - win_size, win_size):
70
+ if np.sum(signal[begin: begin+win_size]) == 0:
71
+ continue
72
+ row = {
73
+ "epoch_idx": epoch_idx,
74
+ "filename": filename.as_posix(),
75
+ "raw_duration": round(raw_duration, 4),
76
+ "offset": round(begin / sample_rate, 4),
77
+ "duration": round(duration, 4),
78
+ }
79
+ yield row
80
+
81
+
82
+ def main():
83
+ args = get_args()
84
+
85
+ file_dir = Path(args.file_dir)
86
+ file_dir.mkdir(exist_ok=True)
87
+
88
+ noise_dir = Path(args.noise_dir)
89
+ speech_dir = Path(args.speech_dir)
90
+
91
+ noise_generator = target_second_signal_generator(
92
+ noise_dir.as_posix(),
93
+ duration=args.duration,
94
+ sample_rate=args.target_sample_rate,
95
+ max_epoch=100000,
96
+ )
97
+ speech_generator = target_second_signal_generator(
98
+ speech_dir.as_posix(),
99
+ duration=args.duration,
100
+ sample_rate=args.target_sample_rate,
101
+ max_epoch=1,
102
+ )
103
+
104
+ dataset = list()
105
+
106
+ count = 0
107
+ process_bar = tqdm(desc="build dataset jsonl")
108
+ with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
109
+ for noise, speech in zip(noise_generator, speech_generator):
110
+ if count >= args.max_count > 0:
111
+ break
112
+
113
+ noise_filename = noise["filename"]
114
+ noise_raw_duration = noise["raw_duration"]
115
+ noise_offset = noise["offset"]
116
+ noise_duration = noise["duration"]
117
+
118
+ speech_filename = speech["filename"]
119
+ speech_raw_duration = speech["raw_duration"]
120
+ speech_offset = speech["offset"]
121
+ speech_duration = speech["duration"]
122
+
123
+ random1 = random.random()
124
+ random2 = random.random()
125
+
126
+ row = {
127
+ "count": count,
128
+
129
+ "noise_filename": noise_filename,
130
+ "noise_raw_duration": noise_raw_duration,
131
+ "noise_offset": noise_offset,
132
+ "noise_duration": noise_duration,
133
+
134
+ "speech_filename": speech_filename,
135
+ "speech_raw_duration": speech_raw_duration,
136
+ "speech_offset": speech_offset,
137
+ "speech_duration": speech_duration,
138
+
139
+ "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
140
+
141
+ "random1": random1,
142
+ }
143
+ row = json.dumps(row, ensure_ascii=False)
144
+ if random2 < (1 / 300 / 1):
145
+ fvalid.write(f"{row}\n")
146
+ else:
147
+ ftrain.write(f"{row}\n")
148
+
149
+ count += 1
150
+ duration_seconds = count * args.duration
151
+ duration_hours = duration_seconds / 3600
152
+
153
+ process_bar.update(n=1)
154
+ process_bar.set_postfix({
155
+ # "duration_seconds": round(duration_seconds, 4),
156
+ "duration_hours": round(duration_hours, 4),
157
+
158
+ })
159
+
160
+ return
161
+
162
+
163
+ if __name__ == "__main__":
164
+ main()
examples/dfnet2/step_2_train_model.py ADDED
@@ -0,0 +1,459 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/Rikorose/DeepFilterNet
5
+ """
6
+ import argparse
7
+ import json
8
+ import logging
9
+ from logging.handlers import TimedRotatingFileHandler
10
+ import os
11
+ import platform
12
+ from pathlib import Path
13
+ import random
14
+ import sys
15
+ import shutil
16
+ from typing import List
17
+
18
19
+
20
+ pwd = os.path.abspath(os.path.dirname(__file__))
21
+ sys.path.append(os.path.join(pwd, "../../"))
22
+
23
+ import numpy as np
24
+ import torch
25
+ import torch.nn as nn
26
+ from torch.nn import functional as F
27
+ from torch.utils.data.dataloader import DataLoader
28
+ from tqdm import tqdm
29
+
30
+ from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
31
+ from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
32
+ from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
33
+ from toolbox.torchaudio.metrics.pesq import run_pesq_score
34
+ from toolbox.torchaudio.models.dfnet2.configuration_dfnet2 import DfNet2Config
35
+ from toolbox.torchaudio.models.dfnet2.modeling_dfnet2 import DfNet2, DfNet2PretrainedModel
36
+
37
+
38
+ def get_args():
39
+ parser = argparse.ArgumentParser()
40
+ parser.add_argument("--train_dataset", default="train.jsonl", type=str)
41
+ parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
42
+
43
+ parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
44
+ parser.add_argument("--patience", default=10, type=int)
45
+ parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
46
+
47
+ parser.add_argument("--config_file", default="config.yaml", type=str)
48
+
49
+ args = parser.parse_args()
50
+ return args
51
+
52
+
53
+ def logging_config(file_dir: str):
54
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
55
+
56
+ logging.basicConfig(format=fmt,
57
+ datefmt="%m/%d/%Y %H:%M:%S",
58
+ level=logging.INFO)
59
+ file_handler = TimedRotatingFileHandler(
60
+ filename=os.path.join(file_dir, "main.log"),
61
+ encoding="utf-8",
62
+ when="D",
63
+ interval=1,
64
+ backupCount=7
65
+ )
66
+ file_handler.setLevel(logging.INFO)
67
+ file_handler.setFormatter(logging.Formatter(fmt))
68
+ logger = logging.getLogger(__name__)
69
+ logger.addHandler(file_handler)
70
+
71
+ return logger
72
+
73
+
74
+ class CollateFunction(object):
75
+ def __init__(self):
76
+ pass
77
+
78
+ def __call__(self, batch: List[dict]):
79
+ clean_audios = list()
80
+ noisy_audios = list()
81
+ snr_db_list = list()
82
+
83
+ for sample in batch:
84
+ # noise_wave: torch.Tensor = sample["noise_wave"]
85
+ clean_audio: torch.Tensor = sample["speech_wave"]
86
+ noisy_audio: torch.Tensor = sample["mix_wave"]
87
+ # snr_db: float = sample["snr_db"]
88
+
89
+ clean_audios.append(clean_audio)
90
+ noisy_audios.append(noisy_audio)
91
+
92
+ clean_audios = torch.stack(clean_audios)
93
+ noisy_audios = torch.stack(noisy_audios)
94
+
95
+ # assert
96
+ if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
97
+ raise AssertionError("nan or inf in clean_audios")
98
+ if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
99
+ raise AssertionError("nan or inf in noisy_audios")
100
+ return clean_audios, noisy_audios
101
+
102
+
103
+ collate_fn = CollateFunction()
104
+
105
+
106
+ def main():
107
+ args = get_args()
108
+
109
+ config = DfNet2Config.from_pretrained(
110
+ pretrained_model_name_or_path=args.config_file,
111
+ )
112
+
113
+ serialization_dir = Path(args.serialization_dir)
114
+ serialization_dir.mkdir(parents=True, exist_ok=True)
115
+
116
+ logger = logging_config(serialization_dir)
117
+
118
+ random.seed(config.seed)
119
+ np.random.seed(config.seed)
120
+ torch.manual_seed(config.seed)
121
+ logger.info(f"set seed: {config.seed}")
122
+
123
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
124
+ n_gpu = torch.cuda.device_count()
125
+ logger.info(f"GPU available count: {n_gpu}; device: {device}")
126
+
127
+ # datasets
128
+ train_dataset = DenoiseJsonlDataset(
129
+ jsonl_file=args.train_dataset,
130
+ expected_sample_rate=config.sample_rate,
131
+ max_wave_value=32768.0,
132
+ min_snr_db=config.min_snr_db,
133
+ max_snr_db=config.max_snr_db,
134
+ # skip=225000,
135
+ )
136
+ valid_dataset = DenoiseJsonlDataset(
137
+ jsonl_file=args.valid_dataset,
138
+ expected_sample_rate=config.sample_rate,
139
+ max_wave_value=32768.0,
140
+ min_snr_db=config.min_snr_db,
141
+ max_snr_db=config.max_snr_db,
142
+ )
143
+ train_data_loader = DataLoader(
144
+ dataset=train_dataset,
145
+ batch_size=config.batch_size,
146
+ # shuffle=True,
147
+ sampler=None,
148
+ # On Linux, data loading can use multiple worker subprocesses; on Windows it cannot.
149
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
150
+ collate_fn=collate_fn,
151
+ pin_memory=False,
152
+ prefetch_factor=None if platform.system() == "Windows" else 2,
153
+ )
154
+ valid_data_loader = DataLoader(
155
+ dataset=valid_dataset,
156
+ batch_size=config.batch_size,
157
+ # shuffle=True,
158
+ sampler=None,
159
+ # On Linux, data loading can use multiple worker subprocesses; on Windows it cannot.
160
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
161
+ collate_fn=collate_fn,
162
+ pin_memory=False,
163
+ prefetch_factor=None if platform.system() == "Windows" else 2,
164
+ )
165
+
166
+ # models
167
+ logger.info(f"prepare models. config_file: {args.config_file}")
168
+ model = DfNet2PretrainedModel(config).to(device)
169
+ model.to(device)
170
+ model.train()
171
+
172
+ # optimizer
173
+ logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
174
+ optimizer = torch.optim.AdamW(model.parameters(), config.lr)
175
+
176
+ # resume training
177
+ last_step_idx = -1
178
+ last_epoch = -1
179
+ for step_idx_str in serialization_dir.glob("steps-*"):
180
+ step_idx_str = Path(step_idx_str)
181
+ step_idx = step_idx_str.stem.split("-")[1]
182
+ step_idx = int(step_idx)
183
+ if step_idx > last_step_idx:
184
+ last_step_idx = step_idx
185
+ # last_epoch = 1
186
+
187
+ if last_step_idx != -1:
188
+ logger.info(f"resume from steps-{last_step_idx}.")
189
+ model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
190
+
191
+ logger.info(f"load state dict for model.")
192
+ with open(model_pt.as_posix(), "rb") as f:
193
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
194
+ model.load_state_dict(state_dict, strict=True)
195
+
196
+ if config.lr_scheduler == "CosineAnnealingLR":
197
+ lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
198
+ optimizer,
199
+ last_epoch=last_epoch,
200
+ # T_max=10 * config.eval_steps,
201
+ # eta_min=0.01 * config.lr,
202
+ **config.lr_scheduler_kwargs,
203
+ )
204
+ elif config.lr_scheduler == "MultiStepLR":
205
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
206
+ optimizer,
207
+ last_epoch=last_epoch,
208
+ milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
209
+ )
210
+ else:
211
+ raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
212
+
213
+ neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
214
+ mr_stft_loss_fn = MultiResolutionSTFTLoss(
215
+ fft_size_list=[256, 512, 1024],
216
+ win_size_list=[256, 512, 1024],
217
+ hop_size_list=[128, 256, 512],
218
+ factor_sc=1.5,
219
+ factor_mag=1.0,
220
+ reduction="mean"
221
+ ).to(device)
222
+
223
+ # training loop
224
+
225
+ # state
226
+ average_pesq_score = 1000000000
227
+ average_loss = 1000000000
228
+ average_mr_stft_loss = 1000000000
229
+ average_neg_si_snr_loss = 1000000000
230
+ average_mask_loss = 1000000000
231
+ average_lsnr_loss = 1000000000
232
+
233
+ model_list = list()
234
+ best_epoch_idx = None
235
+ best_step_idx = None
236
+ best_metric = None
237
+ patience_count = 0
238
+
239
+ step_idx = 0 if last_step_idx == -1 else last_step_idx
240
+
241
+ logger.info("training")
242
+ early_stop_flag = False
243
+ for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
244
+ if early_stop_flag:
245
+ break
246
+
247
+ # train
248
+ model.train()
249
+
250
+ total_pesq_score = 0.
251
+ total_loss = 0.
252
+ total_mr_stft_loss = 0.
253
+ total_neg_si_snr_loss = 0.
254
+ total_mask_loss = 0.
255
+ total_lsnr_loss = 0.
256
+ total_batches = 0.
257
+
258
+ progress_bar_train = tqdm(
259
+ initial=step_idx,
260
+ desc="Training; epoch-{}".format(epoch_idx),
261
+ )
262
+ for train_batch in train_data_loader:
263
+ clean_audios, noisy_audios = train_batch
264
+ clean_audios: torch.Tensor = clean_audios.to(device)
265
+ noisy_audios: torch.Tensor = noisy_audios.to(device)
266
+
267
+ est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
268
+
269
+ mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios)
270
+ neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
271
+ mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
272
+ lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
273
+
274
+ loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.3 * lsnr_loss
275
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
276
+ logger.info(f"find nan or inf in loss.")
277
+ continue
278
+
279
+ denoise_audios_list_r = list(est_wav.detach().cpu().numpy())
280
+ clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
281
+ pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
282
+
283
+ optimizer.zero_grad()
284
+ loss.backward()
285
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
286
+ optimizer.step()
287
+ lr_scheduler.step()
288
+
289
+ total_pesq_score += pesq_score
290
+ total_loss += loss.item()
291
+ total_mr_stft_loss += mr_stft_loss.item()
292
+ total_neg_si_snr_loss += neg_si_snr_loss.item()
293
+ total_mask_loss += mask_loss.item()
294
+ total_lsnr_loss += lsnr_loss.item()
295
+ total_batches += 1
296
+
297
+ average_pesq_score = round(total_pesq_score / total_batches, 4)
298
+ average_loss = round(total_loss / total_batches, 4)
299
+ average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
300
+ average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
301
+ average_mask_loss = round(total_mask_loss / total_batches, 4)
302
+ average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)
303
+
304
+ progress_bar_train.update(1)
305
+ progress_bar_train.set_postfix({
306
+ "lr": lr_scheduler.get_last_lr()[0],
307
+ "pesq_score": average_pesq_score,
308
+ "loss": average_loss,
309
+ "mr_stft_loss": average_mr_stft_loss,
310
+ "neg_si_snr_loss": average_neg_si_snr_loss,
311
+ "mask_loss": average_mask_loss,
312
+ "lsnr_loss": average_lsnr_loss,
313
+ })
314
+
315
+ # evaluation
316
+ step_idx += 1
317
+ if step_idx % config.eval_steps == 0:
318
+ with torch.no_grad():
319
+ torch.cuda.empty_cache()
320
+
321
+ total_pesq_score = 0.
322
+ total_loss = 0.
323
+ total_mr_stft_loss = 0.
324
+ total_neg_si_snr_loss = 0.
325
+ total_mask_loss = 0.
326
+ total_lsnr_loss = 0.
327
+ total_batches = 0.
328
+
329
+ progress_bar_train.close()
330
+ progress_bar_eval = tqdm(
331
+ desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
332
+ )
333
+ for eval_batch in valid_data_loader:
334
+ clean_audios, noisy_audios = eval_batch
335
+ clean_audios: torch.Tensor = clean_audios.to(device)
336
+ noisy_audios: torch.Tensor = noisy_audios.to(device)
337
+
338
+ est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
339
+
340
+ mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios)
341
+ neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
342
+ mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
343
+ lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
344
+
345
+ loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.3 * lsnr_loss
346
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
347
+ logger.info(f"find nan or inf in loss.")
348
+ continue
349
+
350
+ denoise_audios_list_r = list(est_wav.detach().cpu().numpy())
351
+ clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
352
+ pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
353
+
354
+ total_pesq_score += pesq_score
355
+ total_loss += loss.item()
356
+ total_mr_stft_loss += mr_stft_loss.item()
357
+ total_neg_si_snr_loss += neg_si_snr_loss.item()
358
+ total_mask_loss += mask_loss.item()
359
+ total_lsnr_loss += lsnr_loss.item()
360
+ total_batches += 1
361
+
362
+ average_pesq_score = round(total_pesq_score / total_batches, 4)
363
+ average_loss = round(total_loss / total_batches, 4)
364
+ average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
365
+ average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
366
+ average_mask_loss = round(total_mask_loss / total_batches, 4)
367
+ average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)
368
+
369
+ progress_bar_eval.update(1)
370
+ progress_bar_eval.set_postfix({
371
+ "lr": lr_scheduler.get_last_lr()[0],
372
+ "pesq_score": average_pesq_score,
373
+ "loss": average_loss,
374
+ "mr_stft_loss": average_mr_stft_loss,
375
+ "neg_si_snr_loss": average_neg_si_snr_loss,
376
+ "mask_loss": average_mask_loss,
377
+ "lsnr_loss": average_lsnr_loss,
378
+ })
379
+
380
+ total_pesq_score = 0.
381
+ total_loss = 0.
382
+ total_mr_stft_loss = 0.
383
+ total_neg_si_snr_loss = 0.
384
+ total_mask_loss = 0.
385
+ total_lsnr_loss = 0.
386
+ total_batches = 0.
387
+
388
+ progress_bar_eval.close()
389
+ progress_bar_train = tqdm(
390
+ initial=progress_bar_train.n,
391
+ postfix=progress_bar_train.postfix,
392
+ desc=progress_bar_train.desc,
393
+ )
394
+
395
+ # save path
396
+ save_dir = serialization_dir / "steps-{}".format(step_idx)
397
+ save_dir.mkdir(parents=True, exist_ok=False)
398
+
399
+ # save models
400
+ model.save_pretrained(save_dir.as_posix())
401
+
402
+ model_list.append(save_dir)
403
+ if len(model_list) >= args.num_serialized_models_to_keep:
404
+ model_to_delete: Path = model_list.pop(0)
405
+ shutil.rmtree(model_to_delete.as_posix())
406
+
407
+ # save metric
408
+ if best_metric is None:
409
+ best_epoch_idx = epoch_idx
410
+ best_step_idx = step_idx
411
+ best_metric = average_pesq_score
412
+ elif average_pesq_score >= best_metric:
413
+ # greater is better.
414
+ best_epoch_idx = epoch_idx
415
+ best_step_idx = step_idx
416
+ best_metric = average_pesq_score
417
+ else:
418
+ pass
419
+
420
+ metrics = {
421
+ "epoch_idx": epoch_idx,
422
+ "best_epoch_idx": best_epoch_idx,
423
+ "best_step_idx": best_step_idx,
424
+ "pesq_score": average_pesq_score,
425
+ "loss": average_loss,
426
+ "mr_stft_loss": average_mr_stft_loss,
427
+ "neg_si_snr_loss": average_neg_si_snr_loss,
428
+ "mask_loss": average_mask_loss,
429
+ "lsnr_loss": average_lsnr_loss,
430
+ }
431
+ metrics_filename = save_dir / "metrics_epoch.json"
432
+ with open(metrics_filename, "w", encoding="utf-8") as f:
433
+ json.dump(metrics, f, indent=4, ensure_ascii=False)
434
+
435
+ # save best
436
+ best_dir = serialization_dir / "best"
437
+ if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
438
+ if best_dir.exists():
439
+ shutil.rmtree(best_dir)
440
+ shutil.copytree(save_dir, best_dir)
441
+
442
+ # early stop
443
+ early_stop_flag = False
444
+ if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
445
+ patience_count = 0
446
+ else:
447
+ patience_count += 1
448
+ if patience_count >= args.patience:
449
+ early_stop_flag = True
450
+
451
+ # early stop
452
+ if early_stop_flag:
453
+ break
454
+
455
+ return
456
+
457
+
458
+ if __name__ == "__main__":
459
+ main()
examples/dfnet2/yaml/config.yaml ADDED
@@ -0,0 +1,72 @@
+ model_name: "dfnet2"
+
+ # spec
+ sample_rate: 8000
+ nfft: 512
+ win_size: 200
+ hop_size: 80
+
+ spec_bins: 256
+
+ # model
+ conv_channels: 64
+ conv_kernel_size_input:
+   - 3
+   - 3
+ conv_kernel_size_inner:
+   - 1
+   - 3
+ convt_kernel_size_inner:
+   - 1
+   - 3
+
+ embedding_hidden_size: 256
+ encoder_combine_op: "concat"
+
+ encoder_emb_skip_op: "none"
+ encoder_emb_linear_groups: 16
+ encoder_emb_hidden_size: 256
+
+ encoder_linear_groups: 32
+
+ decoder_emb_num_layers: 3
+ decoder_emb_skip_op: "none"
+ decoder_emb_linear_groups: 16
+ decoder_emb_hidden_size: 256
+
+ df_decoder_hidden_size: 256
+ df_num_layers: 2
+ df_order: 5
+ df_bins: 96
+ df_gru_skip: "grouped_linear"
+ df_decoder_linear_groups: 16
+ df_pathway_kernel_size_t: 5
+ df_lookahead: 2
+
+ # lsnr
+ n_frame: 3
+ lsnr_max: 30
+ lsnr_min: -15
+ norm_tau: 1.
+
+ # data
+ min_snr_db: -10
+ max_snr_db: 20
+
+ # train
+ lr: 0.001
+ lr_scheduler: "CosineAnnealingLR"
+ lr_scheduler_kwargs:
+   T_max: 250000
+   eta_min: 0.0001
+
+ max_epochs: 100
+ clip_grad_norm: 10.0
+ seed: 1234
+
+ num_workers: 8
+ batch_size: 64
+ eval_steps: 10000
+
+ # runtime
+ use_post_filter: true
examples/test.py DELETED
@@ -1,39 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- import torch
4
- import torch.nn as nn
5
-
6
-
7
- inputs = torch.randn(size=(1, 1, 16000))
8
-
9
- conv1d = nn.Conv1d(
10
- in_channels=1,
11
- out_channels=1,
12
- kernel_size=3,
13
- stride=2,
14
- padding=0,
15
- dilation=1,
16
- )
17
- conv1dt = nn.ConvTranspose1d(
18
- in_channels=1,
19
- out_channels=1,
20
- kernel_size=3,
21
- stride=2,
22
- padding=0,
23
- output_padding=1,
24
- dilation=1,
25
- )
26
-
27
- x = conv1d.forward(inputs)
28
-
29
- print(x.shape)
30
-
31
- x = conv1dt.forward(x)
32
- print(x.shape)
33
- print(x[:, :, 0])
34
- print(x[:, :, -2])
35
- print(x[:, :, -1])
36
-
37
-
38
- if __name__ == "__main__":
39
- pass
toolbox/torchaudio/models/dfnet/modeling_dfnet_online.py DELETED
@@ -1,226 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- The native DeepFilterNet implementation does not directly support streaming inference.
5
-
6
- Community developers (such as Rikorose) have provided a Torch-based streaming-inference implementation:
7
- https://github.com/grazder/DeepFilterNet/tree/1097015d53ced78fb234e7d7071a5dd4446e3952/torchDF
8
-
9
- This file attempts to implement a dfnet that supports streaming inference.
10
-
11
- """
12
- import os
13
- import math
14
- from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
15
-
16
- import numpy as np
17
- import torch
18
- import torch.nn as nn
19
- from torch.nn import functional as F
20
-
21
- from toolbox.torchaudio.configuration_utils import CONFIG_FILE
22
- from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
23
- from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT
24
- from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget
25
- from toolbox.torchaudio.modules.freq_bands.erb_bands import ErbBands
26
-
27
-
28
- MODEL_FILE = "model.pt"
29
-
30
-
31
- norm_layer_dict = {
32
- "batch_norm_2d": torch.nn.BatchNorm2d
33
- }
34
-
35
-
36
- activation_layer_dict = {
37
- "relu": torch.nn.ReLU,
38
- "identity": torch.nn.Identity,
39
- "sigmoid": torch.nn.Sigmoid,
40
- }
41
-
42
-
43
- class CausalConv2d(nn.Module):
44
- def __init__(self,
45
- in_channels: int,
46
- out_channels: int,
47
- kernel_size: Union[int, Iterable[int]],
48
- fstride: int = 1,
49
- dilation: int = 1,
50
- pad_f_dim: bool = True,
51
- bias: bool = True,
52
- separable: bool = False,
53
- norm_layer: str = "batch_norm_2d",
54
- activation_layer: str = "relu",
55
- ):
56
- super(CausalConv2d, self).__init__()
57
- kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
58
-
59
- if pad_f_dim:
60
- fpad = kernel_size[1] // 2 + dilation - 1
61
- else:
62
- fpad = 0
63
-
64
- # for last 2 dim, pad (left, right, top, bottom).
65
- self.lookback = kernel_size[0] - 1
66
- if self.lookback > 0:
67
- self.tpad = nn.ConstantPad2d(padding=(0, 0, self.lookback, 0), value=0.0)
68
- else:
69
- self.tpad = nn.Identity()
70
-
71
- groups = math.gcd(in_channels, out_channels) if separable else 1
72
- if groups == 1:
73
- separable = False
74
- if max(kernel_size) == 1:
75
- separable = False
76
-
77
- self.conv = nn.Conv2d(
78
- in_channels,
79
- out_channels,
80
- kernel_size=kernel_size,
81
- padding=(0, fpad),
82
- stride=(1, fstride), # stride over time is always 1
83
- dilation=(1, dilation), # dilation over time is always 1
84
- groups=groups,
85
- bias=bias,
86
- )
87
-
88
- if separable:
89
- self.convp = nn.Conv2d(
90
- out_channels,
91
- out_channels,
92
- kernel_size=1,
93
- bias=False,
94
- )
95
- else:
96
- self.convp = nn.Identity()
97
-
98
- if norm_layer is not None:
99
- norm_layer = norm_layer_dict[norm_layer]
100
- self.norm = norm_layer(out_channels)
101
- else:
102
- self.norm = nn.Identity()
103
-
104
- if activation_layer is not None:
105
- activation_layer = activation_layer_dict[activation_layer]
106
- self.activation = activation_layer()
107
- else:
108
- self.activation = nn.Identity()
109
-
110
- super().__init__()
111
-
112
- def forward(self, inputs: torch.Tensor, cache: Tuple[torch.Tensor, torch.Tensor] = None):
113
- """
114
- :param inputs: shape: [b, c, t, f]
115
- :param cache: shape: [b, c, lookback, f];
116
- :return:
117
- """
118
- x = inputs
119
-
120
- if cache is None:
121
- x = self.tpad(x)
122
- else:
123
- x = torch.concat(tensors=[cache, x], dim=2)
124
- new_cache = x[:, :, -self.lookback:, :]
125
-
126
- x = self.conv(x)
127
-
128
- x = self.convp(x)
129
- x = self.norm(x)
130
- x = self.activation(x)
131
-
132
- return x, new_cache
133
-
134
-
135
- class CausalConvTranspose2d(nn.Module):
136
- def __init__(self,
137
- in_channels: int,
138
- out_channels: int,
139
- kernel_size: Union[int, Iterable[int]],
140
- fstride: int = 1,
141
- dilation: int = 1,
142
- pad_f_dim: bool = True,
143
- bias: bool = True,
144
- separable: bool = False,
145
- norm_layer: str = "batch_norm_2d",
146
- activation_layer: str = "relu",
147
- ):
148
- super(CausalConvTranspose2d, self).__init__()
149
-
150
- kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
151
-
152
- if pad_f_dim:
153
- fpad = kernel_size[1] // 2
154
- else:
155
- fpad = 0
156
-
157
- # for last 2 dim, pad (left, right, top, bottom).
158
- self.lookback = kernel_size[0] - 1
159
-
160
- groups = math.gcd(in_channels, out_channels) if separable else 1
161
- if groups == 1:
162
- separable = False
163
-
164
- self.convt = nn.ConvTranspose2d(
165
- in_channels,
166
- out_channels,
167
- kernel_size=kernel_size,
168
- padding=(0, fpad),
169
- output_padding=(0, 0),
170
- stride=(1, fstride), # stride over time is always 1
171
- dilation=(1, dilation), # dilation over time is always 1
172
- groups=groups,
173
- bias=bias,
174
- )
175
-
176
- if separable:
177
- self.convp = nn.Conv2d(
178
- out_channels,
179
- out_channels,
180
- kernel_size=1,
181
- bias=False,
182
- )
183
- else:
184
- self.convp = nn.Identity()
185
-
186
- if norm_layer is not None:
187
- norm_layer = norm_layer_dict[norm_layer]
188
- self.norm = norm_layer(out_channels)
189
- else:
190
- self.norm = nn.Identity()
191
-
192
- if activation_layer is not None:
193
- activation_layer = activation_layer_dict[activation_layer]
194
- self.activation = activation_layer()
195
- else:
196
- self.activation = nn.Identity()
197
-
198
- def forward(self, inputs: torch.Tensor, cache: Tuple[torch.Tensor, torch.Tensor] = None):
199
- """
200
- :param inputs: shape: [b, c, t, f]
201
- :param cache: shape: [b, c, lookback, f];
202
- :return:
203
- """
204
- x = inputs
205
-
206
- # x shape: [b, c, t, f]
207
- x = self.convt(x)
208
- # x shape: [b, c, t+lookback, f]
209
-
210
- if cache is not None:
211
- x = torch.concat(tensors=[
212
- x[:, :, :self.lookback, :] + cache,
213
- x[:, :, self.lookback:, :]
214
- ], dim=2)
215
- x = x[:, :, :-self.lookback, :]
216
- new_cache = x[:, :, -self.lookback:, :]
217
-
218
- x = self.convp(x)
219
- x = self.norm(x)
220
- x = self.activation(x)
221
-
222
- return x, new_cache
223
-
224
-
225
- if __name__ == "__main__":
226
- pass
toolbox/torchaudio/models/dfnet2/__init__.py ADDED
@@ -0,0 +1,6 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+
+
+ if __name__ == '__main__':
+     pass
toolbox/torchaudio/models/dfnet2/configuration_dfnet2.py ADDED
@@ -0,0 +1,147 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from typing import Tuple
4
+
5
+ from toolbox.torchaudio.configuration_utils import PretrainedConfig
6
+
7
+
8
+ class DfNet2Config(PretrainedConfig):
9
+ def __init__(self,
10
+ sample_rate: int = 8000,
11
+ nfft: int = 512,
12
+ win_size: int = 200,
13
+ hop_size: int = 80,
14
+ win_type: str = "hann",
15
+
16
+ spec_bins: int = 256,
17
+ erb_bins: int = 32,
18
+ min_freq_bins_for_erb: int = 2,
19
+
20
+ conv_channels: int = 64,
21
+ conv_kernel_size_input: Tuple[int, int] = (3, 3),
22
+ conv_kernel_size_inner: Tuple[int, int] = (1, 3),
23
+
24
+ convt_kernel_size_inner: Tuple[int, int] = (1, 3),
25
+
26
+ embedding_hidden_size: int = 256,
27
+ encoder_combine_op: str = "concat",
28
+
29
+ encoder_emb_skip_op: str = "none",
30
+ encoder_emb_linear_groups: int = 16,
31
+ encoder_emb_hidden_size: int = 256,
32
+
33
+ encoder_linear_groups: int = 32,
34
+
35
+ decoder_emb_num_layers: int = 3,
36
+ decoder_emb_skip_op: str = "none",
37
+ decoder_emb_linear_groups: int = 16,
38
+ decoder_emb_hidden_size: int = 256,
39
+
40
+ df_decoder_hidden_size: int = 256,
41
+ df_num_layers: int = 2,
42
+ df_order: int = 5,
43
+ df_bins: int = 96,
44
+ df_gru_skip: str = "grouped_linear",
45
+ df_decoder_linear_groups: int = 16,
46
+ df_pathway_kernel_size_t: int = 5,
47
+ df_lookahead: int = 2,
48
+
49
+ n_frame: int = 3,
50
+ max_local_snr: int = 30,
51
+ min_local_snr: int = -15,
52
+ norm_tau: float = 1.,
53
+
54
+ min_snr_db: float = -10,
55
+ max_snr_db: float = 20,
56
+
57
+ lr: float = 0.001,
58
+ lr_scheduler: str = "CosineAnnealingLR",
59
+ lr_scheduler_kwargs: dict = None,
60
+
61
+ max_epochs: int = 100,
62
+ clip_grad_norm: float = 10.,
63
+ seed: int = 1234,
64
+
65
+ num_workers: int = 4,
66
+ batch_size: int = 4,
67
+ eval_steps: int = 25000,
68
+
69
+ use_post_filter: bool = False,
70
+
71
+ **kwargs
72
+ ):
73
+ super(DfNet2Config, self).__init__(**kwargs)
74
+ # transform
75
+ self.sample_rate = sample_rate
76
+ self.nfft = nfft
77
+ self.win_size = win_size
78
+ self.hop_size = hop_size
79
+ self.win_type = win_type
80
+
81
+ # spectrum
82
+ self.spec_bins = spec_bins
83
+ self.erb_bins = erb_bins
84
+ self.min_freq_bins_for_erb = min_freq_bins_for_erb
85
+
86
+ # conv
87
+ self.conv_channels = conv_channels
88
+ self.conv_kernel_size_input = conv_kernel_size_input
89
+ self.conv_kernel_size_inner = conv_kernel_size_inner
90
+
91
+ self.convt_kernel_size_inner = convt_kernel_size_inner
92
+
93
+ self.embedding_hidden_size = embedding_hidden_size
94
+
95
+ # encoder
96
+ self.encoder_emb_skip_op = encoder_emb_skip_op
97
+ self.encoder_emb_linear_groups = encoder_emb_linear_groups
98
+ self.encoder_emb_hidden_size = encoder_emb_hidden_size
99
+
100
+ self.encoder_linear_groups = encoder_linear_groups
101
+ self.encoder_combine_op = encoder_combine_op
102
+
103
+ # decoder
104
+ self.decoder_emb_num_layers = decoder_emb_num_layers
105
+ self.decoder_emb_skip_op = decoder_emb_skip_op
106
+ self.decoder_emb_linear_groups = decoder_emb_linear_groups
107
+ self.decoder_emb_hidden_size = decoder_emb_hidden_size
108
+
109
+ # df decoder
110
+ self.df_decoder_hidden_size = df_decoder_hidden_size
111
+ self.df_num_layers = df_num_layers
112
+ self.df_order = df_order
113
+ self.df_bins = df_bins
114
+ self.df_gru_skip = df_gru_skip
115
+ self.df_decoder_linear_groups = df_decoder_linear_groups
116
+ self.df_pathway_kernel_size_t = df_pathway_kernel_size_t
117
+ self.df_lookahead = df_lookahead
118
+
119
+ # lsnr
120
+ self.n_frame = n_frame
121
+ self.max_local_snr = max_local_snr
122
+ self.min_local_snr = min_local_snr
123
+ self.norm_tau = norm_tau
124
+
125
+ # data snr
126
+ self.min_snr_db = min_snr_db
127
+ self.max_snr_db = max_snr_db
128
+
129
+ # train
130
+ self.lr = lr
131
+ self.lr_scheduler = lr_scheduler
132
+ self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict()
133
+
134
+ self.max_epochs = max_epochs
135
+ self.clip_grad_norm = clip_grad_norm
136
+ self.seed = seed
137
+
138
+ self.num_workers = num_workers
139
+ self.batch_size = batch_size
140
+ self.eval_steps = eval_steps
141
+
142
+ # runtime
143
+ self.use_post_filter = use_post_filter
144
+
145
+
146
+ if __name__ == "__main__":
147
+ pass
toolbox/torchaudio/models/dfnet2/inference_dfnet2.py ADDED
@@ -0,0 +1,115 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import logging
4
+ from pathlib import Path
5
+ import shutil
6
+ import tempfile, time
7
+ import zipfile
8
+
9
+ import librosa
10
+ import numpy as np
11
+ import torch
12
+ import torchaudio
13
+
14
+ torch.set_num_threads(1)
15
+
16
+ from project_settings import project_path
17
+ from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
18
+ from toolbox.torchaudio.models.dfnet.modeling_dfnet import DfNetPretrainedModel, MODEL_FILE
19
+
20
+ logger = logging.getLogger("toolbox")
21
+
22
+
23
+ class InferenceDfNet(object):
24
+ def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
25
+ self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
26
+ self.device = torch.device(device)
27
+
28
+ logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
29
+ config, model = self.load_models(self.pretrained_model_path_or_zip_file)
30
+ logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")
31
+
32
+ self.config = config
33
+ self.model = model
34
+ self.model.to(device)
35
+ self.model.eval()
36
+
37
+ def load_models(self, model_path: str):
38
+ model_path = Path(model_path)
39
+ if model_path.name.endswith(".zip"):
40
+ with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
41
+ out_root = Path(tempfile.gettempdir()) / "nx_denoise"
42
+ out_root.mkdir(parents=True, exist_ok=True)
43
+ f_zip.extractall(path=out_root)
44
+ model_path = out_root / model_path.stem
45
+
46
+ config = DfNetConfig.from_pretrained(
47
+ pretrained_model_name_or_path=model_path.as_posix(),
48
+ )
49
+ model = DfNetPretrainedModel.from_pretrained(
50
+ pretrained_model_name_or_path=model_path.as_posix(),
51
+ )
52
+ model.to(self.device)
53
+ model.eval()
54
+
55
+ shutil.rmtree(model_path)
56
+ return config, model
57
+
58
+ def enhancement_by_ndarray(self, noisy_audio: np.ndarray) -> np.ndarray:
59
+ noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
60
+ noisy_audio = noisy_audio.unsqueeze(dim=0)
61
+
62
+ # noisy_audio shape: [batch_size, n_samples]
63
+ enhanced_audio = self.enhancement_by_tensor(noisy_audio)
64
+ # enhanced_audio shape: [channels, num_samples]
65
+ enhanced_audio = enhanced_audio[0]
66
+ # enhanced_audio shape: [num_samples]
67
+ return enhanced_audio.cpu().numpy()
68
+
69
+ def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor:
70
+ if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
71
+ raise AssertionError(f"The value range of audio samples should be between -1 and 1.")
72
+
73
+ # noisy_audio shape: [batch_size, num_samples]
74
+ noisy_audios = noisy_audio.to(self.device)
75
+
76
+ with torch.no_grad():
77
+ est_spec, est_wav, est_mask, lsnr = self.model.forward(noisy_audios)
78
+
79
+ # shape: [batch_size, num_samples]
80
+ enhanced_audio = torch.unsqueeze(est_wav, dim=1)
81
+ # shape: [batch_size, 1, num_samples]
82
+
83
+ enhanced_audio = enhanced_audio[0]
84
+ # shape: [channels, num_samples]
85
+ return enhanced_audio
86
+
87
+
88
+ def main():
89
+ model_zip_file = project_path / "trained_models/dfnet-nx-dns3.zip"
90
+ infer_model = InferenceDfNet(model_zip_file)
91
+
92
+ sample_rate = 8000
93
+ noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_3.wav"
94
+ noisy_audio, sample_rate = librosa.load(
95
+ noisy_audio_file.as_posix(),
96
+ sr=sample_rate,
97
+ )
98
+ duration = librosa.get_duration(y=noisy_audio, sr=sample_rate)
99
+ # noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)]
100
+ noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
101
+ noisy_audio = noisy_audio.unsqueeze(dim=0)
102
+
103
+ begin = time.time()
104
+ enhanced_audio = infer_model.enhancement_by_tensor(noisy_audio)
105
+ time_cost = time.time() - begin
106
+ print(f"enhanced_audio.shape: {enhanced_audio.shape}, time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}")
107
+
108
+ filename = "enhanced_audio.wav"
109
+ torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate)
110
+
111
+ return
112
+
113
+
114
+ if __name__ == "__main__":
115
+ main()
toolbox/torchaudio/models/dfnet2/modeling_dfnet2.py ADDED
@@ -0,0 +1,1364 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ The native DeepFilterNet implementation does not directly support streaming inference.
5
+
6
+ Community developers (such as Rikorose) have provided a Torch-based streaming-inference implementation:
7
+ https://github.com/grazder/DeepFilterNet/tree/1097015d53ced78fb234e7d7071a5dd4446e3952/torchDF
8
+
9
+ This file attempts to implement a dfnet that supports streaming inference.
10
+
11
+ """
12
+ import os
13
+ import math
14
+ from collections import defaultdict
15
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.nn import functional as F
21
+
22
+ from toolbox.torchaudio.configuration_utils import CONFIG_FILE
23
+ from toolbox.torchaudio.models.dfnet2.configuration_dfnet2 import DfNet2Config
24
+ from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT
25
+ from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget
26
+ from toolbox.torchaudio.modules.freq_bands.erb_bands import ErbBands
27
+
28
+
29
+ MODEL_FILE = "model.pt"
30
+
31
+
32
+ norm_layer_dict = {
33
+ "batch_norm_2d": torch.nn.BatchNorm2d
34
+ }
35
+
36
+
37
+ activation_layer_dict = {
38
+ "relu": torch.nn.ReLU,
39
+ "identity": torch.nn.Identity,
40
+ "sigmoid": torch.nn.Sigmoid,
41
+ }
42
+
43
+
44
+ class CausalConv2d(nn.Module):
45
+ def __init__(self,
46
+ in_channels: int,
47
+ out_channels: int,
48
+ kernel_size: Union[int, Iterable[int]],
49
+ fstride: int = 1,
50
+ dilation: int = 1,
51
+ pad_f_dim: bool = True,
52
+ bias: bool = True,
53
+ separable: bool = False,
54
+ norm_layer: str = "batch_norm_2d",
55
+ activation_layer: str = "relu",
56
+ ):
57
+ super(CausalConv2d, self).__init__()
58
+ kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
59
+
60
+ if pad_f_dim:
61
+ fpad = kernel_size[1] // 2 + dilation - 1
62
+ else:
63
+ fpad = 0
64
+
65
+ # for last 2 dim, pad (left, right, top, bottom).
66
+ self.lookback = kernel_size[0] - 1
67
+ if self.lookback > 0:
68
+ self.tpad = nn.ConstantPad2d(padding=(0, 0, self.lookback, 0), value=0.0)
69
+ else:
70
+ self.tpad = nn.Identity()
71
+
72
+ groups = math.gcd(in_channels, out_channels) if separable else 1
73
+ if groups == 1:
74
+ separable = False
75
+ if max(kernel_size) == 1:
76
+ separable = False
77
+
78
+ self.conv = nn.Conv2d(
79
+ in_channels,
80
+ out_channels,
81
+ kernel_size=kernel_size,
82
+ padding=(0, fpad),
83
+ stride=(1, fstride), # stride over time is always 1
84
+ dilation=(1, dilation), # dilation over time is always 1
85
+ groups=groups,
86
+ bias=bias,
87
+ )
88
+
89
+ if separable:
90
+ self.convp = nn.Conv2d(
91
+ out_channels,
92
+ out_channels,
93
+ kernel_size=1,
94
+ bias=False,
95
+ )
96
+ else:
97
+ self.convp = nn.Identity()
98
+
99
+ if norm_layer is not None:
100
+ norm_layer = norm_layer_dict[norm_layer]
101
+ self.norm = norm_layer(out_channels)
102
+ else:
103
+ self.norm = nn.Identity()
104
+
105
+ if activation_layer is not None:
106
+ activation_layer = activation_layer_dict[activation_layer]
107
+ self.activation = activation_layer()
108
+ else:
109
+ self.activation = nn.Identity()
110
+
111
+ def forward(self, inputs: torch.Tensor, cache: Tuple[torch.Tensor, torch.Tensor] = None):
112
+ """
113
+ :param inputs: shape: [b, c, t, f]
114
+ :param cache: shape: [b, c, lookback, f];
115
+ :return:
116
+ """
117
+ x = inputs
118
+
119
+ if cache is None:
120
+ x = self.tpad(x)
121
+ else:
122
+ x = torch.concat(tensors=[cache, x], dim=2)
123
+
124
+ new_cache = None
125
+ if self.lookback > 0:
126
+ new_cache = x[:, :, -self.lookback:, :]
127
+
128
+ x = self.conv(x)
129
+
130
+ x = self.convp(x)
131
+ x = self.norm(x)
132
+ x = self.activation(x)
133
+
134
+ return x, new_cache
135
+
136
+
137
+ class CausalConvTranspose2d(nn.Module):
138
+ def __init__(self,
139
+ in_channels: int,
140
+ out_channels: int,
141
+ kernel_size: Union[int, Iterable[int]],
142
+ fstride: int = 1,
143
+ dilation: int = 1,
144
+ pad_f_dim: bool = True,
145
+ bias: bool = True,
146
+ separable: bool = False,
147
+ norm_layer: str = "batch_norm_2d",
148
+ activation_layer: str = "relu",
149
+ ):
150
+ super(CausalConvTranspose2d, self).__init__()
151
+
152
+ kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
153
+
154
+ if pad_f_dim:
155
+ fpad = kernel_size[1] // 2
156
+ else:
157
+ fpad = 0
158
+
159
+ # for last 2 dim, pad (left, right, top, bottom).
160
+ self.lookback = kernel_size[0] - 1
161
+
162
+ groups = math.gcd(in_channels, out_channels) if separable else 1
163
+ if groups == 1:
164
+ separable = False
165
+
166
+ self.convt = nn.ConvTranspose2d(
167
+ in_channels,
168
+ out_channels,
169
+ kernel_size=kernel_size,
170
+ padding=(0, fpad),
171
+ output_padding=(0, fpad),
172
+ stride=(1, fstride), # stride over time is always 1
173
+ dilation=(1, dilation), # dilation over time is always 1
174
+ groups=groups,
175
+ bias=bias,
176
+ )
177
+
178
+ if separable:
179
+ self.convp = nn.Conv2d(
180
+ out_channels,
181
+ out_channels,
182
+ kernel_size=1,
183
+ bias=False,
184
+ )
185
+ else:
186
+ self.convp = nn.Identity()
187
+
188
+ if norm_layer is not None:
189
+ norm_layer = norm_layer_dict[norm_layer]
190
+ self.norm = norm_layer(out_channels)
191
+ else:
192
+ self.norm = nn.Identity()
193
+
194
+ if activation_layer is not None:
195
+ activation_layer = activation_layer_dict[activation_layer]
196
+ self.activation = activation_layer()
197
+ else:
198
+ self.activation = nn.Identity()
199
+
200
+ def forward(self, inputs: torch.Tensor, cache: Tuple[torch.Tensor, torch.Tensor] = None):
201
+ """
202
+ :param inputs: shape: [b, c, t, f]
203
+ :param cache: shape: [b, c, lookback, f];
204
+ :return:
205
+ """
206
+ x = inputs
207
+
208
+ # x shape: [b, c, t, f]
209
+ x = self.convt(x)
210
+ # x shape: [b, c, t+lookback, f]
211
+
212
+ new_cache = None
213
+ if self.lookback > 0:
214
+ if cache is not None:
215
+ x = torch.concat(tensors=[
216
+ x[:, :, :self.lookback, :] + cache,
217
+ x[:, :, self.lookback:, :]
218
+ ], dim=2)
219
+
220
+ x = x[:, :, :-self.lookback, :]
221
+ new_cache = x[:, :, -self.lookback:, :]
222
+
223
+ x = self.convp(x)
224
+ x = self.norm(x)
225
+ x = self.activation(x)
226
+
227
+ return x, new_cache
228
+
229
+
230
+ class GroupedLinear(nn.Module):
231
+
232
+ def __init__(self, input_size: int, hidden_size: int, groups: int = 1):
233
+ super().__init__()
234
+ # self.weight: Tensor
235
+ self.input_size = input_size
236
+ self.hidden_size = hidden_size
237
+ self.groups = groups
238
+ assert input_size % groups == 0, f"Input size {input_size} not divisible by {groups}"
239
+ assert hidden_size % groups == 0, f"Hidden size {hidden_size} not divisible by {groups}"
240
+ self.ws = input_size // groups
241
+ self.register_parameter(
242
+ "weight",
243
+ torch.nn.Parameter(
244
+ torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True
245
+ ),
246
+ )
247
+ self.reset_parameters()
248
+
249
+ def reset_parameters(self):
250
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # type: ignore
251
+
252
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
253
+ # x: [..., I]
254
+ b, t, f = x.shape
255
+ if f != self.input_size:
256
+ raise AssertionError
257
+
258
+ # new_shape = list(x.shape)[:-1] + [self.groups, self.ws]
259
+ new_shape = (b, t, self.groups, self.ws)
260
+ x = x.view(new_shape)
261
+ # The better way, but not supported by torchscript
262
+ # x = x.unflatten(-1, (self.groups, self.ws)) # [..., G, I/G]
263
+ x = torch.einsum("btgi,gih->btgh", x, self.weight) # [..., G, H/G]
264
+ x = x.flatten(2, 3)
265
+ # x: [b, t, h]
266
+ return x
267
+
268
+ def __repr__(self):
269
+ cls = self.__class__.__name__
270
+ return f"{cls}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})"
271
+
272
+
273
+ class SqueezedGRU_S(nn.Module):
274
+ """
275
+ SGE net: Video object detection with squeezed GRU and information entropy map
276
+ https://arxiv.org/abs/2106.07224
277
+ """
278
+
279
+ def __init__(
280
+ self,
281
+ input_size: int,
282
+ hidden_size: int,
283
+ output_size: Optional[int] = None,
284
+ num_layers: int = 1,
285
+ linear_groups: int = 8,
286
+ batch_first: bool = True,
287
+ skip_op: str = "none",
288
+ activation_layer: str = "identity",
289
+ ):
290
+ super().__init__()
291
+ self.input_size = input_size
292
+ self.hidden_size = hidden_size
293
+
294
+ self.linear_in = nn.Sequential(
295
+ GroupedLinear(
296
+ input_size=input_size,
297
+ hidden_size=hidden_size,
298
+ groups=linear_groups,
299
+ ),
300
+ activation_layer_dict[activation_layer](),
301
+ )
302
+
303
+ # gru skip operator
304
+ self.gru_skip_op = None
305
+
306
+ if skip_op == "none":
307
+ self.gru_skip_op = None
308
+ elif skip_op == "identity":
309
+ if input_size != output_size:
310
+ raise AssertionError("Dimensions do not match")
311
+ self.gru_skip_op = nn.Identity()
312
+ elif skip_op == "grouped_linear":
313
+ self.gru_skip_op = GroupedLinear(
314
+ input_size=hidden_size,
315
+ hidden_size=hidden_size,
316
+ groups=linear_groups,
317
+ )
318
+ else:
319
+ raise NotImplementedError()
320
+
321
+ self.gru = nn.GRU(
322
+ input_size=hidden_size,
323
+ hidden_size=hidden_size,
324
+ num_layers=num_layers,
325
+ batch_first=batch_first,
326
+ bidirectional=False,
327
+ )
328
+
329
+ if output_size is not None:
330
+ self.linear_out = nn.Sequential(
331
+ GroupedLinear(
332
+ input_size=hidden_size,
333
+ hidden_size=output_size,
334
+ groups=linear_groups,
335
+ ),
336
+ activation_layer_dict[activation_layer](),
337
+ )
338
+ else:
339
+ self.linear_out = nn.Identity()
340
+
341
+ def forward(self, inputs: torch.Tensor, hx: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
342
+ # inputs: shape: [b, t, h]
343
+ x = self.linear_in.forward(inputs)
344
+
345
+ x, hx = self.gru.forward(x, hx)
346
+
347
+ x = self.linear_out(x)
348
+
349
+ if self.gru_skip_op is not None:
350
+ x = x + self.gru_skip_op(inputs)
351
+
352
+ return x, hx
353
+
354
+
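A small usage sketch of the squeezed GRU (hypothetical sizes): the grouped linear layers compress the features before and after a regular nn.GRU, so the expensive recurrence runs at hidden_size rather than input_size.

import torch

gru = SqueezedGRU_S(
    input_size=1024, hidden_size=256, output_size=512,
    num_layers=1, linear_groups=8, batch_first=True,
    skip_op="none", activation_layer="relu",
)
x = torch.randn(2, 10, 1024)       # [b, t, input_size]
y, h = gru.forward(x)              # y: [b, t, output_size], h: [num_layers, b, hidden_size]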
355
+ class Add(nn.Module):
356
+ def forward(self, a, b):
357
+ return a + b
358
+
359
+
360
+ class Concat(nn.Module):
361
+ def forward(self, a, b):
362
+ return torch.cat((a, b), dim=-1)
363
+
364
+
365
+ class Encoder(nn.Module):
366
+ def __init__(self, config: DfNet2Config):
367
+ super(Encoder, self).__init__()
368
+ self.embedding_input_size = config.conv_channels * config.erb_bins // 4
369
+ self.embedding_output_size = config.conv_channels * config.erb_bins // 4
370
+ self.embedding_hidden_size = config.embedding_hidden_size
371
+
372
+ self.spec_conv0 = CausalConv2d(
373
+ in_channels=1,
374
+ out_channels=config.conv_channels,
375
+ kernel_size=config.conv_kernel_size_input,
376
+ bias=False,
377
+ separable=True,
378
+ fstride=1,
379
+ )
380
+ self.spec_conv1 = CausalConv2d(
381
+ in_channels=config.conv_channels,
382
+ out_channels=config.conv_channels,
383
+ kernel_size=config.conv_kernel_size_inner,
384
+ bias=False,
385
+ separable=True,
386
+ fstride=2,
387
+ )
388
+ self.spec_conv2 = CausalConv2d(
389
+ in_channels=config.conv_channels,
390
+ out_channels=config.conv_channels,
391
+ kernel_size=config.conv_kernel_size_inner,
392
+ bias=False,
393
+ separable=True,
394
+ fstride=2,
395
+ )
396
+ self.spec_conv3 = CausalConv2d(
397
+ in_channels=config.conv_channels,
398
+ out_channels=config.conv_channels,
399
+ kernel_size=config.conv_kernel_size_inner,
400
+ bias=False,
401
+ separable=True,
402
+ fstride=1,
403
+ )
404
+
405
+ self.df_conv0 = CausalConv2d(
406
+ in_channels=2,
407
+ out_channels=config.conv_channels,
408
+ kernel_size=config.conv_kernel_size_input,
409
+ bias=False,
410
+ separable=True,
411
+ fstride=1,
412
+ )
413
+ self.df_conv1 = CausalConv2d(
414
+ in_channels=config.conv_channels,
415
+ out_channels=config.conv_channels,
416
+ kernel_size=config.conv_kernel_size_inner,
417
+ bias=False,
418
+ separable=True,
419
+ fstride=2,
420
+ )
421
+ self.df_fc_emb = nn.Sequential(
422
+ GroupedLinear(
423
+ config.conv_channels * config.df_bins // 2,
424
+ self.embedding_input_size,
425
+ groups=config.encoder_linear_groups
426
+ ),
427
+ nn.ReLU(inplace=True)
428
+ )
429
+
430
+ if config.encoder_combine_op == "concat":
431
+ self.embedding_input_size *= 2
432
+ self.combine = Concat()
433
+ else:
434
+ self.combine = Add()
435
+
436
+ # emb_gru
437
+ if config.spec_bins % 8 != 0:
438
+ raise AssertionError("spec_bins should be divisible by 8")
439
+
440
+ self.emb_gru = SqueezedGRU_S(
441
+ self.embedding_input_size,
442
+ self.embedding_hidden_size,
443
+ output_size=self.embedding_output_size,
444
+ num_layers=1,
445
+ batch_first=True,
446
+ skip_op=config.encoder_emb_skip_op,
447
+ linear_groups=config.encoder_emb_linear_groups,
448
+ activation_layer="relu",
449
+ )
450
+
451
+ # lsnr
452
+ self.lsnr_fc = nn.Sequential(
453
+ nn.Linear(self.embedding_output_size, 1),
454
+ nn.Sigmoid()
455
+ )
456
+ self.lsnr_scale = config.max_local_snr - config.min_local_snr
457
+ self.lsnr_offset = config.min_local_snr
458
+
459
+ def forward(self,
460
+ feat_erb: torch.Tensor,
461
+ feat_spec: torch.Tensor,
462
+ cache_dict: dict = None,
463
+ ):
464
+ if cache_dict is None:
465
+ cache_dict = defaultdict(lambda: None)
466
+ cache0 = cache_dict["cache0"]
467
+ cache1 = cache_dict["cache1"]
468
+ cache2 = cache_dict["cache2"]
469
+ cache3 = cache_dict["cache3"]
470
+ cache4 = cache_dict["cache4"]
471
+ cache5 = cache_dict["cache5"]
472
+ cache6 = cache_dict["cache6"]
473
+
474
+ # feat_erb shape: (b, 1, t, erb_bins)
475
+ e0, new_cache0 = self.spec_conv0.forward(feat_erb, cache=cache0)
476
+ e1, new_cache1 = self.spec_conv1.forward(e0, cache=cache1)
477
+ e2, new_cache2 = self.spec_conv2.forward(e1, cache=cache2)
478
+ e3, new_cache3 = self.spec_conv3.forward(e2, cache=cache3)
479
+ # e0 shape: [b, c, t, erb_bins]
480
+ # e1 shape: [b, c, t, erb_bins // 2]
481
+ # e2 shape: [b, c, t, erb_bins // 4]
482
+ # e3 shape: [b, c, t, erb_bins // 4]
483
+ # e3 shape: [b, 64, t, 32/4=8]
484
+
485
+ # feat_spec, shape: (b, 2, t, df_bins)
486
+ c0, new_cache4 = self.df_conv0.forward(feat_spec, cache=cache4)
487
+ c1, new_cache5 = self.df_conv1.forward(c0, cache=cache5)
488
+ # c0 shape: [b, c, t, df_bins]
489
+ # c1 shape: [b, c, t, df_bins // 2]
490
+ # c1 shape: [b, 64, t, 96/2=48]
491
+
492
+ cemb = c1.permute(0, 2, 3, 1)
493
+ # cemb shape: [b, t, df_bins // 2, c]
494
+ cemb = cemb.flatten(2)
495
+ # cemb shape: [b, t, df_bins // 2 * c]
496
+ # cemb shape: [b, t, 96/2*64=3072]
497
+ cemb = self.df_fc_emb.forward(cemb)
498
+ # cemb shape: [b, t, erb_bins // 4 * c]
499
+ # cemb shape: [b, t, 32/4*64=512]
500
+
501
+ # e3 shape: [b, c, t, erb_bins // 4]
502
+ emb = e3.permute(0, 2, 3, 1)
503
+ # emb shape: [b, t, erb_bins // 4, c]
504
+ emb = emb.flatten(2)
505
+ # emb shape: [b, t, erb_bins // 4 * c]
506
+ # emb shape: [b, t, 32/4*64=512]
507
+
508
+ emb = self.combine(emb, cemb)
509
+ # if concat; emb shape: [b, t, spec_bins // 4 * c * 2]
510
+ # if add; emb shape: [b, t, spec_bins // 4 * c]
511
+
512
+ emb, new_cache6 = self.emb_gru.forward(emb, hx=cache6)
513
+
514
+ # emb shape: [b, t, spec_dim // 4 * c]
515
+ # h shape: [b, 1, spec_dim]
516
+
517
+ lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
518
+ # lsnr shape: [b, t, 1]
519
+
520
+ new_cache_dict = {
521
+ "cache0": new_cache0,
522
+ "cache1": new_cache1,
523
+ "cache2": new_cache2,
524
+ "cache3": new_cache3,
525
+ "cache4": new_cache4,
526
+ "cache5": new_cache5,
527
+ "cache6": new_cache6,
528
+ }
529
+ return e0, e1, e2, e3, emb, c0, lsnr, new_cache_dict
530
+
531
+
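For orientation, the embedding sizes implied by the shape comments in Encoder.forward above (conv_channels=64, erb_bins=32, df_bins=96 are assumed values, not read from this file):

# erb branch : e3 is [b, 64, t, 32 // 4] -> flattened to [b, t, 64 * 8]  = [b, t, 512]
# spec branch: c1 is [b, 64, t, 96 // 2] -> flattened to [b, t, 64 * 48] = [b, t, 3072]
#              df_fc_emb then projects 3072 -> 512
# encoder_combine_op="concat" doubles the embedding to [b, t, 1024] before the squeezed GRU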
532
+ class ErbDecoder(nn.Module):
533
+ def __init__(self, config: DfNet2Config):
534
+ super(ErbDecoder, self).__init__()
535
+
536
+ if config.spec_bins % 8 != 0:
537
+ raise AssertionError("spec_bins should be divisible by 8")
538
+
539
+ self.emb_in_dim = config.conv_channels * config.erb_bins // 4
540
+ self.emb_out_dim = config.conv_channels * config.erb_bins // 4
541
+ self.emb_hidden_dim = config.decoder_emb_hidden_size
542
+
543
+ self.emb_gru = SqueezedGRU_S(
544
+ self.emb_in_dim,
545
+ self.emb_hidden_dim,
546
+ output_size=self.emb_out_dim,
547
+ num_layers=config.decoder_emb_num_layers - 1,
548
+ batch_first=True,
549
+ skip_op=config.decoder_emb_skip_op,
550
+ linear_groups=config.decoder_emb_linear_groups,
551
+ activation_layer="relu",
552
+ )
553
+ self.conv3p = CausalConv2d(
554
+ in_channels=config.conv_channels,
555
+ out_channels=config.conv_channels,
556
+ kernel_size=1,
557
+ bias=False,
558
+ separable=True,
559
+ fstride=1,
560
+ )
561
+ self.convt3 = CausalConv2d(
562
+ in_channels=config.conv_channels,
563
+ out_channels=config.conv_channels,
564
+ kernel_size=config.conv_kernel_size_inner,
565
+ bias=False,
566
+ separable=True,
567
+ fstride=1,
568
+ )
569
+ self.conv2p = CausalConv2d(
570
+ in_channels=config.conv_channels,
571
+ out_channels=config.conv_channels,
572
+ kernel_size=1,
573
+ bias=False,
574
+ separable=True,
575
+ fstride=1,
576
+ )
577
+ self.convt2 = CausalConvTranspose2d(
578
+ in_channels=config.conv_channels,
579
+ out_channels=config.conv_channels,
580
+ kernel_size=config.convt_kernel_size_inner,
581
+ bias=False,
582
+ separable=True,
583
+ fstride=2,
584
+ )
585
+ self.conv1p = CausalConv2d(
586
+ in_channels=config.conv_channels,
587
+ out_channels=config.conv_channels,
588
+ kernel_size=1,
589
+ bias=False,
590
+ separable=True,
591
+ fstride=1,
592
+ )
593
+ self.convt1 = CausalConvTranspose2d(
594
+ in_channels=config.conv_channels,
595
+ out_channels=config.conv_channels,
596
+ kernel_size=config.convt_kernel_size_inner,
597
+ bias=False,
598
+ separable=True,
599
+ fstride=2,
600
+ )
601
+ self.conv0p = CausalConv2d(
602
+ in_channels=config.conv_channels,
603
+ out_channels=config.conv_channels,
604
+ kernel_size=1,
605
+ bias=False,
606
+ separable=True,
607
+ fstride=1,
608
+ )
609
+ self.conv0_out = CausalConv2d(
610
+ in_channels=config.conv_channels,
611
+ out_channels=1,
612
+ kernel_size=config.conv_kernel_size_inner,
613
+ activation_layer="sigmoid",
614
+ bias=False,
615
+ separable=True,
616
+ fstride=1,
617
+ )
618
+
619
+ def forward(self, emb, e3, e2, e1, e0, cache_dict: dict = None) -> Tuple[torch.Tensor, dict]:
620
+ if cache_dict is None:
621
+ cache_dict = defaultdict(lambda: None)
622
+ cache0 = cache_dict["cache0"]
623
+ cache1 = cache_dict["cache1"]
624
+ cache2 = cache_dict["cache2"]
625
+ cache3 = cache_dict["cache3"]
626
+ cache4 = cache_dict["cache4"]
627
+
628
+ # Estimates erb mask
629
+ b, _, t, f8 = e3.shape
630
+
631
+ # emb shape: [batch_size, time_steps, (freq_dim // 4) * conv_channels]
632
+ emb, new_cache0 = self.emb_gru.forward(emb, hx=cache0)
633
+ # emb shape: [batch_size, conv_channels, time_steps, freq_dim // 4]
634
+ emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2)
635
+
636
+ e3, new_cache1 = self.convt3.forward(self.conv3p(e3)[0] + emb, cache=cache1)
637
+ # e3 shape: [batch_size, conv_channels, time_steps, freq_dim // 4]
638
+ e2, new_cache2 = self.convt2.forward(self.conv2p(e2)[0] + e3, cache=cache2)
639
+ # e2 shape: [batch_size, conv_channels, time_steps, freq_dim // 2]
640
+ e1, new_cache3 = self.convt1.forward(self.conv1p(e1)[0] + e2, cache=cache3)
641
+ # e1 shape: [batch_size, conv_channels, time_steps, freq_dim]
642
+ mask, new_cache4 = self.conv0_out.forward(self.conv0p(e0)[0] + e1, cache=cache4)
643
+ # mask shape: [batch_size, 1, time_steps, freq_dim]
644
+
645
+ new_cache_dict = {
646
+ "cache0": new_cache0,
647
+ "cache1": new_cache1,
648
+ "cache2": new_cache2,
649
+ "cache3": new_cache3,
650
+ "cache4": new_cache4,
651
+ }
652
+ return mask, new_cache_dict
653
+
654
+
655
+ class DfDecoder(nn.Module):
656
+ def __init__(self, config: DfNet2Config):
657
+ super(DfDecoder, self).__init__()
658
+
659
+ self.embedding_input_size = config.conv_channels * config.erb_bins // 4
660
+ self.df_decoder_hidden_size = config.df_decoder_hidden_size
661
+ self.df_num_layers = config.df_num_layers
662
+
663
+ self.df_order = config.df_order
664
+
665
+ self.df_bins = config.df_bins
666
+ self.df_out_ch = config.df_order * 2
667
+
668
+ self.df_convp = CausalConv2d(
669
+ config.conv_channels,
670
+ self.df_out_ch,
671
+ fstride=1,
672
+ kernel_size=(config.df_pathway_kernel_size_t, 1),
673
+ separable=True,
674
+ bias=False,
675
+ )
676
+ self.df_gru = SqueezedGRU_S(
677
+ self.embedding_input_size,
678
+ self.df_decoder_hidden_size,
679
+ num_layers=self.df_num_layers,
680
+ batch_first=True,
681
+ skip_op="none",
682
+ activation_layer="relu",
683
+ )
684
+
685
+ if config.df_gru_skip == "none":
686
+ self.df_skip = None
687
+ elif config.df_gru_skip == "identity":
688
+ if config.embedding_hidden_size != config.df_decoder_hidden_size:
689
+ raise AssertionError("Dimensions do not match")
690
+ self.df_skip = nn.Identity()
691
+ elif config.df_gru_skip == "grouped_linear":
692
+ self.df_skip = GroupedLinear(
693
+ self.embedding_input_size,
694
+ self.df_decoder_hidden_size,
695
+ groups=config.df_decoder_linear_groups
696
+ )
697
+ else:
698
+ raise NotImplementedError()
699
+
700
+ self.df_out: nn.Module
701
+ out_dim = self.df_bins * self.df_out_ch
702
+
703
+ self.df_out = nn.Sequential(
704
+ GroupedLinear(
705
+ input_size=self.df_decoder_hidden_size,
706
+ hidden_size=out_dim,
707
+ groups=config.df_decoder_linear_groups,
708
+ # groups = self.df_bins // 5,
709
+ ),
710
+ nn.Tanh()
711
+ )
712
+ self.df_fc_a = nn.Sequential(
713
+ nn.Linear(self.df_decoder_hidden_size, 1),
714
+ nn.Sigmoid()
715
+ )
716
+
717
+ def forward(self, emb: torch.Tensor, c0: torch.Tensor, cache_dict: dict = None) -> Tuple[torch.Tensor, dict]:
718
+ if cache_dict is None:
719
+ cache_dict = defaultdict(lambda: None)
720
+ cache0 = cache_dict["cache0"]
721
+ cache1 = cache_dict["cache1"]
722
+
723
+ # emb shape: [batch_size, time_steps, df_bins // 4 * channels]
724
+ b, t, _ = emb.shape
725
+ df_coefs, new_cache0 = self.df_gru.forward(emb, hx=cache0)
726
+ if self.df_skip is not None:
727
+ df_coefs = df_coefs + self.df_skip(emb)
728
+ # df_coefs shape: [batch_size, time_steps, df_decoder_hidden_size]
729
+
730
+ # c0 shape: [batch_size, channels, time_steps, df_bins]
731
+ c0, new_cache1 = self.df_convp.forward(c0, cache=cache1)
732
+ # c0 shape: [batch_size, df_order * 2, time_steps, df_bins]
733
+ c0 = c0.permute(0, 2, 3, 1)
734
+ # c0 shape: [batch_size, time_steps, df_bins, df_order * 2]
735
+
736
+ df_coefs = self.df_out(df_coefs) # [B, T, F*O*2], O: df_order
737
+ # df_coefs shape: [batch_size, time_steps, df_bins * df_order * 2]
738
+ df_coefs = df_coefs.view(b, t, self.df_bins, self.df_out_ch)
739
+ # df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2]
740
+ df_coefs = df_coefs + c0
741
+ # df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2]
742
+
743
+ new_cache_dict = {
744
+ "cache0": new_cache0,
745
+ "cache1": new_cache1,
746
+ }
747
+ return df_coefs, new_cache_dict
748
+
749
+
750
+ class DfOutputReshapeMF(nn.Module):
751
+ """Coefficients output reshape for multiframe/MultiFrameModule
752
+
753
+ Reshapes the deep-filter coefficients from [B, T, F, O*2] to the [B, O, T, F, 2] layout required by the multiframe module.
754
+ """
755
+
756
+ def __init__(self, df_order: int, df_bins: int):
757
+ super().__init__()
758
+ self.df_order = df_order
759
+ self.df_bins = df_bins
760
+
761
+ def forward(self, coefs: torch.Tensor) -> torch.Tensor:
762
+ # [B, T, F, O*2] -> [B, O, T, F, 2]
763
+ new_shape = list(coefs.shape)
764
+ new_shape[-1] = -1
765
+ new_shape.append(2)
766
+ coefs = coefs.view(new_shape)
767
+ coefs = coefs.permute(0, 3, 1, 2, 4)
768
+ return coefs
769
+
770
+
771
+ class Mask(nn.Module):
772
+ def __init__(self, use_post_filter: bool = False, eps: float = 1e-12):
773
+ super().__init__()
774
+ self.use_post_filter = use_post_filter
775
+ self.eps = eps
776
+
777
+ def post_filter(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor:
778
+ """
779
+ Post-Filter
780
+
781
+ A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech.
782
+ https://arxiv.org/abs/2008.04259
783
+
784
+ :param mask: Real valued mask, typically of shape [B, C, T, F].
785
+ :param beta: Global gain factor.
786
+ :return:
787
+ """
788
+ mask_sin = mask * torch.sin(np.pi * mask / 2)
789
+ mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(self.eps)).pow(2))
790
+ return mask_pf
791
+
792
+ def forward(self, spec: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
793
+ # spec shape: [b, 1, t, spec_bins, 2]
794
+
795
+ if not self.training and self.use_post_filter:
796
+ mask = self.post_filter(mask)
797
+
798
+ # mask shape: [b, 1, t, spec_bins]
799
+ mask = mask.unsqueeze(4)
800
+ # mask shape: [b, 1, t, spec_bins, 1]
801
+ return spec * mask
802
+
803
+
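A quick numeric look at the post filter above (toy values, beta at its default): it pushes small mask values further toward zero, i.e. more aggressive suppression in noise-dominated bins, while leaving values near 1 essentially untouched.

import torch

pf = Mask()
m = torch.tensor([[[[0.1, 0.5, 0.9, 1.0]]]])       # [b, c, t, f]
print(pf.post_filter(m, beta=0.02))
# approximately: 0.1 -> 0.06, 0.5 -> 0.49, 0.9 -> 0.90, 1.0 -> 1.00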
804
+ class DeepFiltering(nn.Module):
805
+ def __init__(self,
806
+ df_bins: int,
807
+ df_order: int,
808
+ lookahead: int = 0,
809
+ ):
810
+ super(DeepFiltering, self).__init__()
811
+ self.df_bins = df_bins
812
+ self.df_order = df_order
813
+ self.lookahead = lookahead
814
+
815
+ self.pad = nn.ConstantPad2d((0, 0, df_order - 1 - lookahead, lookahead), 0.0)
816
+
817
+ def forward(self, *args, **kwargs):
818
+ raise AssertionError("use `forward_offline` or `forward_online` instead.")
819
+
820
+ def spec_unfold_offline(self, spec: torch.Tensor) -> torch.Tensor:
821
+ """
822
+ Pads and unfolds the spectrogram according to frame_size.
823
+ :param spec: shape: [b, c, t, f], dtype: torch.complex64
824
+ :return: shape: [b, c, t, f, df_order]
825
+ """
826
+ if self.df_order <= 1:
827
+ return spec.unsqueeze(-1)
828
+
829
+ # spec shape: [b, 1, t, f], dtype: torch.complex64
830
+ spec = self.pad(spec)
831
+ # spec_pad shape: [b, 1, t+df_order-1, f], dtype: torch.complex64
832
+ spec_unfold = spec.unfold(dimension=2, size=self.df_order, step=1)
833
+ # spec_unfold shape: [b, 1, t, f, df_order], dtype: torch.complex64
834
+ return spec_unfold
835
+
836
+ def forward_offline(self,
837
+ spec: torch.Tensor,
838
+ coefs: torch.Tensor,
839
+ ):
840
+ # spec shape: [b, 1, t, spec_bins, 2]
841
+ spec_c = torch.view_as_complex(spec.contiguous())
842
+ # spec_c shape: [b, 1, t, spec_bins]
843
+ spec_u = self.spec_unfold_offline(spec_c)
844
+ # spec_u shape: [b, 1, t, spec_bins, df_order]
845
+ spec_f = spec_u.narrow(dim=-2, start=0, length=self.df_bins)
846
+ # spec_f shape: [b, 1, t, df_bins, df_order]
847
+
848
+ # coefs shape: [b, df_order, t, df_bins, 2]
849
+ coefs = torch.view_as_complex(coefs.contiguous())
850
+ # coefs shape: [b, df_order, t, df_bins]
851
+ coefs = coefs.unsqueeze(dim=1)
852
+ # coefs shape: [b, 1, df_order, t, df_bins]
853
+
854
+ spec_f = self.df_offline(spec_f, coefs)
855
+ # spec_f shape: [b, 1, t, df_bins]
856
+
857
+ spec_f = torch.view_as_real(spec_f)
858
+ # spec_f shape: [b, 1, t, df_bins, 2]
859
+ return spec_f
860
+
861
+ def df_offline(self, spec: torch.Tensor, coefs: torch.Tensor):
862
+ """
863
+ Deep filter implementation using `torch.einsum`. Requires unfolded spectrogram.
864
+ :param spec: [b, 1, t, df_bins, df_order] complex.
865
+ :param coefs: [b, 1, df_order, t, df_bins] complex.
866
+ :return: [b, 1, t, df_bins] complex.
867
+ """
868
+ spec_f = torch.einsum("...tfn,...ntf->...tf", spec, coefs)
869
+ return spec_f
870
+
871
+ def spec_unfold_online(self, spec: torch.Tensor, cache_spec: torch.Tensor = None):
872
+ """
873
+ Pads and unfolds the spectrogram according to frame_size.
874
+ :param spec: shape: [b, c, t, f], dtype: torch.complex64
875
+ :param cache_spec: shape: [b, c, df_order-1, f], dtype: torch.complex64
876
+ :return: shape: [b, c, t, f, df_order]
877
+ """
878
+ if self.df_order <= 1:
879
+ return spec.unsqueeze(-1)
880
+
881
+ if cache_spec is None:
882
+ b, c, _, f = spec.shape
883
+ cache_spec = spec.new_zeros(size=(b, c, self.df_order-1, f))
884
+ spec_pad = torch.concat(tensors=[
885
+ cache_spec, spec
886
+ ], dim=2)
887
+ new_cache_spec = spec_pad[:, :, -(self.df_order-1):, :]
888
+
889
+ # spec_pad shape: [b, 1, t+df_order-1, f], dtype: torch.complex64
890
+ spec_unfold = spec_pad.unfold(dimension=2, size=self.df_order, step=1)
891
+ # spec_unfold shape: [b, 1, t, f, df_order], dtype: torch.complex64
892
+ return spec_unfold, new_cache_spec
893
+
894
+ def forward_online(self,
895
+ spec: torch.Tensor,
896
+ coefs: torch.Tensor,
897
+ cache_dict: dict = None,
898
+ ):
899
+ if cache_dict is None:
900
+ cache_dict = defaultdict(lambda: None)
901
+ cache0 = cache_dict["cache0"]
902
+ cache1 = cache_dict["cache1"]
903
+
904
+ # spec shape: [b, 1, t, spec_bins, 2]
905
+ spec_c = torch.view_as_complex(spec.contiguous())
906
+ # spec_c shape: [b, 1, t, spec_bins]
907
+ spec_u, new_cache0 = self.spec_unfold_online(spec_c, cache_spec=cache0)
908
+ # spec_u shape: [b, 1, t, spec_bins, df_order]
909
+ spec_f = spec_u.narrow(dim=-2, start=0, length=self.df_bins)
910
+ # spec_f shape: [b, 1, t, df_bins, df_order]
911
+
912
+ # coefs shape: [b, df_order, t, df_bins, 2]
913
+ coefs = torch.view_as_complex(coefs.contiguous())
914
+ # coefs shape: [b, df_order, t, df_bins]
915
+ coefs = coefs.unsqueeze(dim=1)
916
+ # coefs shape: [b, 1, df_order, t, df_bins]
917
+
918
+ spec_f, new_cache1 = self.df_online(spec_f, coefs, cache_coefs=cache1)
919
+ # spec_f shape: [b, 1, t, df_bins]
920
+
921
+ spec_f = torch.view_as_real(spec_f)
922
+ # spec_f shape: [b, 1, t, df_bins, 2]
923
+
924
+ new_cache_dict = {
925
+ "cache0": new_cache0,
926
+ "cache1": new_cache1,
927
+ }
928
+ return spec_f, new_cache_dict
929
+
930
+ def df_online(self, spec: torch.Tensor, coefs: torch.Tensor, cache_coefs: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
931
+ """
932
+ Deep filter implementation using `torch.einsum`. Requires unfolded spectrogram.
933
+ :param spec: [b, 1, 1, df_bins, df_order] complex.
934
+ :param coefs: [b, 1, df_order, 1, df_bins] complex.
935
+ :param cache_coefs: [b, 1, df_order, lookahead, df_bins] complex.
936
+ :return: [b, 1, 1, df_bins] complex.
937
+ """
938
+
939
+ if cache_coefs is None:
940
+ b, c, _, _, f = coefs.shape
941
+ cache_coefs = coefs.new_zeros(size=(b, c, self.df_order, self.lookahead, f))
942
+ coefs_pad = torch.concat(tensors=[
943
+ cache_coefs, coefs
944
+ ], dim=3)
945
+
946
+ # coefs_pad shape: [b, 1, df_order, 1+lookahead, df_bins], torch.complex64.
947
+ coefs = coefs_pad[:, :, :, :-self.lookahead, :]
948
+ # coefs shape: [b, 1, df_order, 1, df_bins], torch.complex64.
949
+ new_cache_coefs = coefs_pad[:, :, :, -self.lookahead:, :]
950
+ # new_cache_coefs shape: [b, 1, df_order, lookahead, df_bins], torch.complex64.
951
+ spec_f = torch.einsum("...tfn,...ntf->...tf", spec, coefs)
952
+ return spec_f, new_cache_coefs
953
+
954
+
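The einsum used by df_offline/df_online is a per-bin complex FIR filter of length df_order over the unfolded frames; a small equivalence check with toy shapes:

import torch

b, t, f, order = 1, 4, 6, 3
spec_u = torch.randn(b, 1, t, f, order, dtype=torch.complex64)    # unfolded spectrogram
coefs = torch.randn(b, 1, order, t, f, dtype=torch.complex64)     # filter taps

out = torch.einsum("...tfn,...ntf->...tf", spec_u, coefs)
ref = sum(spec_u[..., n] * coefs[:, :, n] for n in range(order))  # explicit sum over taps
assert torch.allclose(out, ref)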
955
+ class DfNet2(nn.Module):
956
+ def __init__(self, config: DfNet2Config):
957
+ super(DfNet2, self).__init__()
958
+ self.config = config
959
+ self.eps = 1e-12
960
+
961
+ self.freq_bins = self.config.nfft // 2 + 1
962
+
963
+ self.nfft = config.nfft
964
+ self.win_size = config.win_size
965
+ self.hop_size = config.hop_size
966
+ self.win_type = config.win_type
967
+
968
+ self.erb_bands = ErbBands(
969
+ sample_rate=config.sample_rate,
970
+ nfft=config.nfft,
971
+ erb_bins=config.erb_bins,
972
+ min_freq_bins_for_erb=config.min_freq_bins_for_erb,
973
+ )
974
+
975
+ self.stft = ConvSTFT(
976
+ nfft=config.nfft,
977
+ win_size=config.win_size,
978
+ hop_size=config.hop_size,
979
+ win_type=config.win_type,
980
+ power=None,
981
+ requires_grad=False
982
+ )
983
+ self.istft = ConviSTFT(
984
+ nfft=config.nfft,
985
+ win_size=config.win_size,
986
+ hop_size=config.hop_size,
987
+ win_type=config.win_type,
988
+ requires_grad=False
989
+ )
990
+
991
+ self.encoder = Encoder(config)
992
+ self.erb_decoder = ErbDecoder(config)
993
+
994
+ self.df_decoder = DfDecoder(config)
995
+ self.df_out_transform = DfOutputReshapeMF(config.df_order, config.df_bins)
996
+ self.df_op = DeepFiltering(
997
+ df_bins=config.df_bins,
998
+ df_order=config.df_order,
999
+ lookahead=config.df_lookahead,
1000
+ )
1001
+
1002
+ self.mask = Mask(use_post_filter=config.use_post_filter)
1003
+
1004
+ self.lsnr_fn = LocalSnrTarget(
1005
+ sample_rate=config.sample_rate,
1006
+ nfft=config.nfft,
1007
+ win_size=config.win_size,
1008
+ hop_size=config.hop_size,
1009
+ n_frame=config.n_frame,
1010
+ min_local_snr=config.min_local_snr,
1011
+ max_local_snr=config.max_local_snr,
1012
+ db=True,
1013
+ )
1014
+
1015
+ def signal_prepare(self, signal: torch.Tensor) -> torch.Tensor:
1016
+ if signal.dim() == 2:
1017
+ signal = torch.unsqueeze(signal, dim=1)
1018
+ _, _, n_samples = signal.shape
1019
+ remainder = (n_samples - self.win_size) % self.hop_size
1020
+ if remainder > 0:
1021
+ n_samples_pad = self.hop_size - remainder
1022
+ signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
1023
+ return signal
1024
+
1025
+ def feature_prepare(self, signal: torch.Tensor):
1026
+ # noisy shape: [b, num_samples_pad]
1027
+ spec_cmp = self.stft.forward(signal)
1028
+ # spec_complex shape: [b, f, t], torch.complex64
1029
+ spec_cmp = torch.transpose(spec_cmp, dim0=1, dim1=2)
1030
+ # spec_complex shape: [b, t, f], torch.complex64
1031
+ spec_cmp_real = torch.view_as_real(spec_cmp)
1032
+ # spec_cmp_real shape: [b, t, f, 2]
1033
+ spec_mag = torch.abs(spec_cmp)
1034
+ spec_pow = torch.square(spec_mag)
1035
+ # shape: [b, t, f]
1036
+
1037
+ spec = torch.unsqueeze(spec_cmp_real, dim=1)
1038
+ # spec shape: [b, 1, t, f, 2]
1039
+
1040
+ feat_erb = self.erb_bands.erb_scale(spec_pow, db=True)
1041
+ # feat_erb shape: [b, t, erb_bins]
1042
+ feat_erb = torch.unsqueeze(feat_erb, dim=1)
1043
+ # feat_erb shape: [b, 1, t, erb_bins]
1044
+
1045
+ feat_spec = spec_cmp_real.permute(0, 3, 1, 2)
1046
+ # feat_spec shape: [b, 2, t, f]
1047
+ feat_spec = feat_spec[..., :self.df_decoder.df_bins]
1048
+ # feat_spec shape: [b, 2, t, df_bins]
1049
+
1050
+ return spec, feat_erb, feat_spec
1051
+
1052
+ def forward(self,
1053
+ noisy: torch.Tensor,
1054
+ ):
1055
+ """
1056
+ :param noisy:
1057
+ :return:
1058
+ est_spec: shape: [b, 257*2, t]
1059
+ est_wav: shape: [b, num_samples]
1060
+ est_mask: shape: [b, 257, t]
1061
+ lsnr: shape: [b, 1, t]
1062
+ """
1063
+ n_samples = noisy.shape[-1]
1064
+ noisy = self.signal_prepare(noisy)
1065
+
1066
+ spec, feat_erb, feat_spec = self.feature_prepare(noisy)
1067
+
1068
+ e0, e1, e2, e3, emb, c0, lsnr, _ = self.encoder.forward(feat_erb, feat_spec)
1069
+
1070
+ mask, _ = self.erb_decoder.forward(emb, e3, e2, e1, e0)
1071
+ # mask shape: [b, 1, t, erb_bins]
1072
+ mask = self.erb_bands.erb_scale_inv(mask)
1073
+ # mask shape: [b, 1, t, f]
1074
+ if torch.any(mask > 1) or torch.any(mask < 0):
1075
+ raise AssertionError("mask values must lie in [0, 1]")
1076
+
1077
+ spec_m = self.mask.forward(spec, mask)
1078
+ # spec_m shape: [b, 1, t, f, 2]
1079
+ spec_m = spec_m[:, :, :, :self.config.spec_bins, :]
1080
+ # spec_m shape: [b, 1, t, spec_bins, 2]
1081
+
1082
+ # lsnr shape: [b, t, 1]
1083
+ lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
1084
+ # lsnr shape: [b, 1, t]
1085
+
1086
+ df_coefs, _ = self.df_decoder.forward(emb, c0)
1087
+ df_coefs = self.df_out_transform(df_coefs)
1088
+ # df_coefs shape: [b, df_order, t, df_bins, 2]
1089
+
1090
+ spec_ = spec[:, :, :, :self.config.spec_bins, :]
1091
+ # spec shape: [b, 1, t, spec_bins, 2]
1092
+ spec_f = self.df_op.forward_offline(spec_, df_coefs)
1093
+ # spec_f shape: [b, 1, t, df_bins, 2], torch.float32
1094
+
1095
+ spec_e = torch.concat(tensors=[
1096
+ spec_f, spec_m[..., self.df_decoder.df_bins:, :]
1097
+ ], dim=3)
1098
+
1099
+ spec_e = torch.squeeze(spec_e, dim=1)
1100
+ spec_e = spec_e.permute(0, 2, 1, 3)
1101
+ # spec_e shape: [b, spec_bins, t, 2]
1102
+
1103
+ # spec_e shape: [b, spec_bins, t, 2]
1104
+ est_spec = torch.view_as_complex(spec_e.contiguous())
1105
+ # est_spec shape: [b, spec_bins, t], torch.complex64
1106
+ est_spec = torch.concat(tensors=[est_spec, est_spec[:, -1:, :]], dim=1)
1107
+ # est_spec shape: [b, f, t], torch.complex64
1108
+
1109
+ est_wav = self.istft.forward(est_spec)
1110
+ est_wav = est_wav[:, :, :n_samples]
1111
+ # est_wav shape: [b, 1, n_samples]
1112
+
1113
+ est_mask = torch.squeeze(mask, dim=1)
1114
+ est_mask = est_mask.permute(0, 2, 1)
1115
+ # est_mask shape: [b, f, t]
1116
+
1117
+ return est_spec, est_wav, est_mask, lsnr
1118
+
1119
+ def forward_chunk_by_chunk(self,
1120
+ noisy: torch.Tensor,
1121
+ ):
1122
+ noisy = self.signal_prepare(noisy)
1123
+ b, _, _ = noisy.shape
1124
+ noisy = torch.concat(tensors=[
1125
+ noisy, noisy.new_zeros(size=(b, 1, (self.config.df_lookahead+1)*self.hop_size))
1126
+ ], dim=2)
1127
+ b, _, num_samples = noisy.shape
1128
+
1129
+ t = (num_samples - self.win_size) // self.hop_size + 1
1130
+
1131
+ cache_dict0 = None
1132
+ cache_dict1 = None
1133
+ cache_dict2 = None
1134
+ cache_dict3 = None
1135
+ cache_dict4 = None
1136
+ cache_dict5 = None
1137
+
1138
+ waveform_list = list()
1139
+ for i in range(int(t)):
1140
+ begin = i * self.hop_size
1141
+ end = begin + self.win_size
1142
+ sub_noisy = noisy[:, :, begin: end]
1143
+
1144
+ spec, feat_erb, feat_spec = self.feature_prepare(sub_noisy)
1145
+ # spec shape: [b, 1, t, f, 2]
1146
+ # feat_erb shape: [b, 1, t, erb_bins]
1147
+ # feat_spec shape: [b, 2, t, df_bins]
1148
+
1149
+ e0, e1, e2, e3, emb, c0, lsnr, cache_dict0 = self.encoder.forward(feat_erb, feat_spec, cache_dict=cache_dict0)
1150
+
1151
+ mask, cache_dict1 = self.erb_decoder.forward(emb, e3, e2, e1, e0, cache_dict=cache_dict1)
1152
+ # mask shape: [b, 1, t, erb_bins]
1153
+ mask = self.erb_bands.erb_scale_inv(mask)
1154
+ # mask shape: [b, 1, t, f]
1155
+
1156
+ spec_m = self.mask.forward(spec, mask)
1157
+ # spec_m shape: [b, 1, t, f, 2]
1158
+ spec_m = spec_m[:, :, :, :self.config.spec_bins, :]
1159
+ # spec_m shape: [b, 1, t, spec_bins, 2]
1160
+
1161
+ # lsnr shape: [b, t, 1]
1162
+ lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
1163
+ # lsnr shape: [b, 1, t]
1164
+
1165
+ df_coefs, cache_dict2 = self.df_decoder.forward(emb, c0, cache_dict=cache_dict2)
1166
+ df_coefs = self.df_out_transform(df_coefs)
1167
+ # df_coefs shape: [b, df_order, t, df_bins, 2]
1168
+
1169
+ spec_ = spec[:, :, :, :self.config.spec_bins, :]
1170
+ # spec shape: [b, 1, t, spec_bins, 2]
1171
+ spec_f, cache_dict3 = self.df_op.forward_online(spec_, df_coefs, cache_dict=cache_dict3)
1172
+ # spec_f shape: [b, 1, t, df_bins, 2], torch.float32
1173
+
1174
+ spec_e = torch.concat(tensors=[
1175
+ spec_f, spec_m[..., self.df_decoder.df_bins:, :]
1176
+ ], dim=3)
1177
+
1178
+ spec_e, cache_dict4 = self.spec_e_m_combine_online(spec_f, spec_m, cache_dict=cache_dict4)
1179
+
1180
+ spec_e = torch.squeeze(spec_e, dim=1)
1181
+ spec_e = spec_e.permute(0, 2, 1, 3)
1182
+ # spec_e shape: [b, spec_bins, t, 2]
1183
+
1184
+ # spec_e shape: [b, spec_bins, t, 2]
1185
+ est_spec = torch.view_as_complex(spec_e.contiguous())
1186
+ # est_spec shape: [b, spec_bins, t], torch.complex64
1187
+ est_spec = torch.concat(tensors=[est_spec, est_spec[:, -1:, :]], dim=1)
1188
+ # est_spec shape: [b, f, t], torch.complex64
1189
+
1190
+ est_wav, cache_dict5 = self.istft.forward_chunk(est_spec, cache_dict=cache_dict5)
1191
+ # est_wav shape: [b, 1, hop_size]
1192
+
1193
+ waveform_list.append(est_wav)
1194
+
1195
+ waveform = torch.concat(tensors=waveform_list, dim=-1)
1196
+ # waveform shape: [b, 1, n]
1197
+ return waveform
1198
+
1199
+ def spec_e_m_combine_online(self, spec_f: torch.Tensor, spec_m: torch.Tensor, cache_dict: dict = None):
1200
+ """
1201
+ :param spec_f: shape: [b, 1, t, df_bins, 2], torch.float32
1202
+ :param spec_m: shape: [b, 1, t, spec_bins, 2]
1203
+ :param cache_dict:
1204
+ :return:
1205
+ """
1206
+ if cache_dict is None:
1207
+ cache_dict = defaultdict(lambda: None)
1208
+ cache_spec_m = cache_dict["cache_spec_m"]
1209
+
1210
+ if cache_spec_m is None:
1211
+ b, c, t, f, _ = spec_m.shape
1212
+ cache_spec_m = spec_m.new_zeros(size=(b, c, self.config.df_lookahead, f, 2))
1213
+ # cache0 shape: [b, 1, lookahead, f, 2]
1214
+ spec_m_cat = torch.concat(tensors=[
1215
+ cache_spec_m, spec_m,
1216
+ ], dim=2)
1217
+
1218
+ spec_m = spec_m_cat[:, :, :-self.config.df_lookahead, :, :]
1219
+ new_cache_spec_m = spec_m_cat[:, :, -self.config.df_lookahead:, :, :]
1220
+
1221
+ spec_e = torch.concat(tensors=[
1222
+ spec_f, spec_m[..., self.df_decoder.df_bins:, :]
1223
+ ], dim=3)
1224
+
1225
+ new_cache_dict = {
1226
+ "cache_spec_m": new_cache_spec_m,
1227
+ }
1228
+ return spec_e, new_cache_dict
1229
+
1230
+ def mask_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
1231
+ """
1232
+ :param est_mask: torch.Tensor, shape: [b, 257, t]
1233
+ :param clean:
1234
+ :param noisy:
1235
+ :return:
1236
+ """
1237
+ if noisy.shape != clean.shape:
1238
+ raise AssertionError("Input signals must have the same shape")
1239
+ noise = noisy - clean
1240
+
1241
+ clean = self.signal_prepare(clean)
1242
+ noise = self.signal_prepare(noise)
1243
+
1244
+ stft_clean = self.stft.forward(clean)
1245
+ mag_clean = torch.abs(stft_clean)
1246
+
1247
+ stft_noise = self.stft.forward(noise)
1248
+ mag_noise = torch.abs(stft_noise)
1249
+
1250
+ gth_irm_mask = (mag_clean / (mag_clean + mag_noise + self.eps)).clamp(0, 1)
1251
+
1252
+ loss = F.l1_loss(gth_irm_mask, est_mask, reduction="mean")
1253
+
1254
+ return loss
1255
+
1256
+ def lsnr_loss_fn(self, lsnr: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
1257
+ if noisy.shape != clean.shape:
1258
+ raise AssertionError("Input signals must have the same shape")
1259
+ noise = noisy - clean
1260
+
1261
+ clean = self.signal_prepare(clean)
1262
+ noise = self.signal_prepare(noise)
1263
+
1264
+ stft_clean = self.stft.forward(clean)
1265
+ stft_noise = self.stft.forward(noise)
1266
+ # shape: [b, f, t]
1267
+ stft_clean = torch.transpose(stft_clean, dim0=1, dim1=2)
1268
+ stft_noise = torch.transpose(stft_noise, dim0=1, dim1=2)
1269
+ # shape: [b, t, f]
1270
+ stft_clean = torch.unsqueeze(stft_clean, dim=1)
1271
+ stft_noise = torch.unsqueeze(stft_noise, dim=1)
1272
+ # shape: [b, 1, t, f]
1273
+
1274
+ # lsnr shape: [b, 1, t]
1275
+ lsnr = lsnr.squeeze(1)
1276
+ # lsnr shape: [b, t]
1277
+
1278
+ lsnr_gth = self.lsnr_fn.forward(stft_clean, stft_noise)
1279
+ # lsnr_gth shape: [b, t]
1280
+
1281
+ loss = F.mse_loss(lsnr, lsnr_gth)
1282
+ return loss
1283
+
1284
+
1285
+ class DfNet2PretrainedModel(DfNet2):
1286
+ def __init__(self,
1287
+ config: DfNet2Config,
1288
+ ):
1289
+ super(DfNet2PretrainedModel, self).__init__(
1290
+ config=config,
1291
+ )
1292
+
1293
+ @classmethod
1294
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
1295
+ config = DfNet2Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
1296
+
1297
+ model = cls(config)
1298
+
1299
+ if os.path.isdir(pretrained_model_name_or_path):
1300
+ ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
1301
+ else:
1302
+ ckpt_file = pretrained_model_name_or_path
1303
+
1304
+ with open(ckpt_file, "rb") as f:
1305
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
1306
+ model.load_state_dict(state_dict, strict=True)
1307
+ return model
1308
+
1309
+ def save_pretrained(self,
1310
+ save_directory: Union[str, os.PathLike],
1311
+ state_dict: Optional[dict] = None,
1312
+ ):
1313
+
1314
+ model = self
1315
+
1316
+ if state_dict is None:
1317
+ state_dict = model.state_dict()
1318
+
1319
+ os.makedirs(save_directory, exist_ok=True)
1320
+
1321
+ # save state dict
1322
+ model_file = os.path.join(save_directory, MODEL_FILE)
1323
+ torch.save(state_dict, model_file)
1324
+
1325
+ # save config
1326
+ config_file = os.path.join(save_directory, CONFIG_FILE)
1327
+ self.config.to_yaml_file(config_file)
1328
+ return save_directory
1329
+
1330
+
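A hedged round-trip sketch for the pretrained wrapper above ("trained_models/dfnet2-demo" is a hypothetical path; this assumes DfNet2Config.from_pretrained can read the yaml written by to_yaml_file, as the methods suggest):

model = DfNet2PretrainedModel(config=DfNet2Config())
save_dir = model.save_pretrained("trained_models/dfnet2-demo")
restored = DfNet2PretrainedModel.from_pretrained(save_dir)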
1331
+ def main():
1332
+
1333
+ config = DfNet2Config()
1334
+ model = DfNet2PretrainedModel(config=config)
1335
+ model.eval()
1336
+
1337
+ noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
1338
+
1339
+ est_spec, est_wav, est_mask, lsnr = model.forward(noisy)
1340
+ # print(f"est_spec.shape: {est_spec.shape}")
1341
+ # print(f"est_wav.shape: {est_wav.shape}")
1342
+ # print(f"est_mask.shape: {est_mask.shape}")
1343
+ # print(f"lsnr.shape: {lsnr.shape}")
1344
+
1345
+ waveform = est_wav
1346
+ print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
1347
+ print(waveform[:, :, 300: 302])
1348
+ print(waveform[:, :, 15680: 15682])
1349
+ print(waveform[:, :, 15760: 15762])
1350
+ print(waveform[:, :, 15840: 15842])
1351
+
1352
+ waveform = model.forward_chunk_by_chunk(noisy)
1353
+ waveform = waveform[:, :, (config.df_lookahead*config.hop_size):]
1354
+ print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
1355
+ print(waveform[:, :, 300: 302])
1356
+ print(waveform[:, :, 15680: 15682])
1357
+ print(waveform[:, :, 15760: 15762])
1358
+ print(waveform[:, :, 15840: 15842])
1359
+
1360
+ return
1361
+
1362
+
1363
+ if __name__ == "__main__":
1364
+ main()
toolbox/torchaudio/models/dfnet2/yaml/config.yaml ADDED
@@ -0,0 +1,72 @@
1
+ model_name: "dfnet2"
2
+
3
+ # spec
4
+ sample_rate: 8000
5
+ nfft: 512
6
+ win_size: 200
7
+ hop_size: 80
8
+
9
+ spec_bins: 256
10
+
11
+ # model
12
+ conv_channels: 64
13
+ conv_kernel_size_input:
14
+ - 3
15
+ - 3
16
+ conv_kernel_size_inner:
17
+ - 1
18
+ - 3
19
+ convt_kernel_size_inner:
20
+ - 1
21
+ - 3
22
+
23
+ embedding_hidden_size: 256
24
+ encoder_combine_op: "concat"
25
+
26
+ encoder_emb_skip_op: "none"
27
+ encoder_emb_linear_groups: 16
28
+ encoder_emb_hidden_size: 256
29
+
30
+ encoder_linear_groups: 32
31
+
32
+ decoder_emb_num_layers: 3
33
+ decoder_emb_skip_op: "none"
34
+ decoder_emb_linear_groups: 16
35
+ decoder_emb_hidden_size: 256
36
+
37
+ df_decoder_hidden_size: 256
38
+ df_num_layers: 2
39
+ df_order: 5
40
+ df_bins: 96
41
+ df_gru_skip: "grouped_linear"
42
+ df_decoder_linear_groups: 16
43
+ df_pathway_kernel_size_t: 5
44
+ df_lookahead: 2
45
+
46
+ # lsnr
47
+ n_frame: 3
48
+ lsnr_max: 30
49
+ lsnr_min: -15
50
+ norm_tau: 1.
51
+
52
+ # data
53
+ min_snr_db: -10
54
+ max_snr_db: 20
55
+
56
+ # train
57
+ lr: 0.001
58
+ lr_scheduler: "CosineAnnealingLR"
59
+ lr_scheduler_kwargs:
60
+ T_max: 250000
61
+ eta_min: 0.0001
62
+
63
+ max_epochs: 100
64
+ clip_grad_norm: 10.0
65
+ seed: 1234
66
+
67
+ num_workers: 8
68
+ batch_size: 64
69
+ eval_steps: 10000
70
+
71
+ # runtime
72
+ use_post_filter: true
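For reference, the frame timing implied by the settings above (8000 Hz sample rate, win_size 200, hop_size 80, df_lookahead 2): 25 ms analysis windows with a 10 ms hop, and roughly 20 ms of extra algorithmic latency in the streaming path, assuming one hop per chunk as in forward_chunk_by_chunk.

sample_rate, win_size, hop_size, df_lookahead = 8000, 200, 80, 2
print(win_size / sample_rate * 1000)                    # 25.0 ms window
print(hop_size / sample_rate * 1000)                    # 10.0 ms hop
print(df_lookahead * hop_size / sample_rate * 1000)     # 20.0 ms lookahead latency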
toolbox/torchaudio/models/lstm/modeling_lstm.py CHANGED
@@ -238,14 +238,13 @@ def main():
238
  print(waveform[:, :, 300: 302])
239
 
240
  # 2
241
- waveform_cache = None
242
- coff_cache = None
243
  waveform = torch.zeros(size=(b, 1, num_samples), dtype=torch.float32)
244
  for i in range(int(t)):
245
  sub_spec = spec[:, :, i:i+1]
246
  begin = i * config.hop_size
247
  end = begin + config.win_size - config.hop_size
248
- sub_waveform, waveform_cache, coff_cache = model.istft.forward_chunk(sub_spec, waveform_cache, coff_cache)
249
  # end = begin + config.win_size
250
  # sub_waveform = model.istft.forward(sub_spec)
251
 
 
238
  print(waveform[:, :, 300: 302])
239
 
240
  # 2
241
+ cache_dict = None
 
242
  waveform = torch.zeros(size=(b, 1, num_samples), dtype=torch.float32)
243
  for i in range(int(t)):
244
  sub_spec = spec[:, :, i:i+1]
245
  begin = i * config.hop_size
246
  end = begin + config.win_size - config.hop_size
247
+ sub_waveform, cache_dict = model.istft.forward_chunk(sub_spec, cache_dict=cache_dict)
248
  # end = begin + config.win_size
249
  # sub_waveform = model.istft.forward(sub_spec)
250
 
toolbox/torchaudio/models/rnnoise/modeling_rnnoise.py CHANGED
@@ -232,8 +232,7 @@ class RNNoise(nn.Module):
232
  waveform = torch.zeros(size=(b, 1, 0), dtype=torch.float32)
233
 
234
  states = None
235
- waveform_cache = None
236
- coff_cache = None
237
 
238
  cache_list = list()
239
  for i in range(int(t)):
@@ -274,7 +273,7 @@ class RNNoise(nn.Module):
274
  mask = self.erb_bands.erb_scale_inv(mask_erb)
275
  mask = torch.transpose(mask, dim0=1, dim1=2)
276
  stft_denoise = self.do_mask(mag_noisy, pha_noisy, mask)
277
- sub_waveform, waveform_cache, coff_cache = self.istft.forward_chunk(stft_denoise, waveform_cache, coff_cache)
278
  waveform = torch.concat(tensors=[waveform, sub_waveform], dim=-1)
279
 
280
  return waveform
 
232
  waveform = torch.zeros(size=(b, 1, 0), dtype=torch.float32)
233
 
234
  states = None
235
+ cache_dict = None
 
236
 
237
  cache_list = list()
238
  for i in range(int(t)):
 
273
  mask = self.erb_bands.erb_scale_inv(mask_erb)
274
  mask = torch.transpose(mask, dim0=1, dim1=2)
275
  stft_denoise = self.do_mask(mag_noisy, pha_noisy, mask)
276
+ sub_waveform, cache_dict = self.istft.forward_chunk(stft_denoise, cache_dict=cache_dict)
277
  waveform = torch.concat(tensors=[waveform, sub_waveform], dim=-1)
278
 
279
  return waveform
toolbox/torchaudio/modules/conv_stft.py CHANGED
@@ -3,6 +3,7 @@
3
  """
4
  https://github.com/modelscope/modelscope/blob/master/modelscope/models/audio/ans/conv_stft.py
5
  """
 
6
  import numpy as np
7
  import torch
8
  import torch.nn as nn
@@ -144,15 +145,20 @@ class ConviSTFT(nn.Module):
144
  @torch.no_grad()
145
  def forward_chunk(self,
146
  spec: torch.Tensor,
147
- waveform_cache: torch.Tensor = None,
148
- coff_cache: torch.Tensor = None,
149
  ):
150
  """
151
  :param spec: shape: [b, f, t]
152
- :param waveform_cache: shape: [b, 1, win_size - hop_size]
153
- :param coff_cache: shape: [b, 1, win_size - hop_size]
 
154
  :return:
155
  """
 
 
 
 
 
156
  spec = torch.view_as_real(spec)
157
  matrix = torch.concat(tensors=[spec[..., 0], spec[..., 1]], dim=1)
158
 
@@ -174,7 +180,12 @@ class ConviSTFT(nn.Module):
174
  new_coff_cache = coff_current[:, :, self.hop_size:]
175
 
176
  waveform_output = waveform_output / (coff_output + 1e-8)
177
- return waveform_output, new_waveform_cache, new_coff_cache
 
 
 
 
 
178
 
179
 
180
  def main():
@@ -238,15 +249,14 @@ def main2():
238
  print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
239
  print(waveform[:, :, 300: 302])
240
 
241
- waveform_cache = None
242
- coff_cache = None
243
  waveform = torch.zeros(size=(b, 1, num_samples), dtype=torch.float32)
244
  for i in range(int(t)):
245
  sub_spec = spec[:, :, i:i+1]
246
  begin = i * hop_size
247
 
248
  end = begin + win_size - hop_size
249
- sub_waveform, waveform_cache, coff_cache = istft.forward_chunk(sub_spec, waveform_cache, coff_cache)
250
  # end = begin + win_size
251
  # sub_waveform = istft.forward(sub_spec)
252
 
 
3
  """
4
  https://github.com/modelscope/modelscope/blob/master/modelscope/models/audio/ans/conv_stft.py
5
  """
6
+ from collections import defaultdict
7
  import numpy as np
8
  import torch
9
  import torch.nn as nn
 
145
  @torch.no_grad()
146
  def forward_chunk(self,
147
  spec: torch.Tensor,
148
+ cache_dict: dict = None
 
149
  ):
150
  """
151
  :param spec: shape: [b, f, t]
152
+ :param cache_dict: dict,
153
+ waveform_cache shape: [b, 1, win_size - hop_size]
154
+ coff_cache shape: [b, 1, win_size - hop_size]
155
  :return:
156
  """
157
+ if cache_dict is None:
158
+ cache_dict = defaultdict(lambda: None)
159
+ waveform_cache = cache_dict["waveform_cache"]
160
+ coff_cache = cache_dict["coff_cache"]
161
+
162
  spec = torch.view_as_real(spec)
163
  matrix = torch.concat(tensors=[spec[..., 0], spec[..., 1]], dim=1)
164
 
 
180
  new_coff_cache = coff_current[:, :, self.hop_size:]
181
 
182
  waveform_output = waveform_output / (coff_output + 1e-8)
183
+
184
+ new_cache_dict = {
185
+ "waveform_cache": new_waveform_cache,
186
+ "coff_cache": new_coff_cache,
187
+ }
188
+ return waveform_output, new_cache_dict
189
 
190
 
191
  def main():
 
249
  print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
250
  print(waveform[:, :, 300: 302])
251
 
252
+ cache_dict = None
 
253
  waveform = torch.zeros(size=(b, 1, num_samples), dtype=torch.float32)
254
  for i in range(int(t)):
255
  sub_spec = spec[:, :, i:i+1]
256
  begin = i * hop_size
257
 
258
  end = begin + win_size - hop_size
259
+ sub_waveform, cache_dict = istft.forward_chunk(sub_spec, cache_dict=cache_dict)
260
  # end = begin + win_size
261
  # sub_waveform = istft.forward(sub_spec)
262
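A minimal end-to-end sketch of the refactored cache interface: the first call passes cache_dict=None and every later call passes the dict returned by the previous one (win_type="hann" is an assumed value, not taken from this diff).

import torch

stft = ConvSTFT(nfft=512, win_size=200, hop_size=80, win_type="hann", power=None, requires_grad=False)
istft = ConviSTFT(nfft=512, win_size=200, hop_size=80, win_type="hann", requires_grad=False)

signal = torch.randn(1, 1, 16000)
spec = stft.forward(signal)                     # [b, f, t], torch.complex64
cache_dict = None
chunks = []
for i in range(spec.shape[-1]):
    sub_waveform, cache_dict = istft.forward_chunk(spec[:, :, i:i + 1], cache_dict=cache_dict)
    chunks.append(sub_waveform)
waveform = torch.concat(chunks, dim=-1)         # [b, 1, n]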