HoneyTian committed
Commit da78a0e · 1 Parent(s): 1d4c9c3
examples/conv_tasnet/step_1_prepare_data.py CHANGED
@@ -107,7 +107,7 @@ def main():
     process_bar = tqdm(desc="build dataset excel")
     with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
         for noise, speech in zip(noise_generator, speech_generator):
-            if count >= args.max_count:
+            if count >= args.max_count > 0:
                 break
 
             noise_filename = noise["filename"]
@@ -124,6 +124,8 @@ def main():
             random2 = random.random()
 
             row = {
+                "count": count,
+
                 "noise_filename": noise_filename,
                 "noise_raw_duration": noise_raw_duration,
                 "noise_offset": noise_offset,
examples/dfnet/run.sh ADDED
@@ -0,0 +1,153 @@
+#!/usr/bin/env bash
+
+: <<'END'
+
+
+sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name frcrn \
+--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
+--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
+
+
+END
+
+
+# params
+system_version="windows";
+verbose=true;
+stage=0 # start from 0 if you need to start from data preparation
+stop_stage=9
+
+work_dir="$(pwd)"
+file_folder_name=file_folder_name
+final_model_name=final_model_name
+config_file="yaml/config.yaml"
+limit=10
+
+noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
+speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
+
+max_count=10000000
+
+nohup_name=nohup.out
+
+# model params
+batch_size=64
+max_epochs=200
+save_top_k=10
+patience=5
+
+
+# parse options
+while true; do
+  [ -z "${1:-}" ] && break;  # break if there are no arguments
+  case "$1" in
+    --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
+      eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
+      old_value="$(eval echo \$$name)";
+      if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
+        was_bool=true;
+      else
+        was_bool=false;
+      fi
+
+      # Set the variable to the right value-- the escaped quotes make it work if
+      # the option had spaces, like --cmd "queue.pl -sync y"
+      eval "${name}=\"$2\"";
+
+      # Check that Boolean-valued arguments are really Boolean.
+      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
+        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
+        exit 1;
+      fi
+      shift 2;
+      ;;
+
+    *) break;
+  esac
+done
+
+file_dir="${work_dir}/${file_folder_name}"
+final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
+evaluation_audio_dir="${file_dir}/evaluation_audio"
+
+train_dataset="${file_dir}/train.jsonl"
+valid_dataset="${file_dir}/valid.jsonl"
+
+$verbose && echo "system_version: ${system_version}"
+$verbose && echo "file_folder_name: ${file_folder_name}"
+
+if [ $system_version == "windows" ]; then
+  alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
+elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
+  #source /data/local/bin/nx_denoise/bin/activate
+  alias python3='/data/local/bin/nx_denoise/bin/python3'
+fi
+
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+  $verbose && echo "stage 1: prepare data"
+  cd "${work_dir}" || exit 1
+  python3 step_1_prepare_data.py \
+    --file_dir "${file_dir}" \
+    --noise_dir "${noise_dir}" \
+    --speech_dir "${speech_dir}" \
+    --train_dataset "${train_dataset}" \
+    --valid_dataset "${valid_dataset}" \
+    --max_count "${max_count}" \
+
+fi
+
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+  $verbose && echo "stage 2: train model"
+  cd "${work_dir}" || exit 1
+  python3 step_2_train_model.py \
+    --train_dataset "${train_dataset}" \
+    --valid_dataset "${valid_dataset}" \
+    --serialization_dir "${file_dir}" \
+    --config_file "${config_file}" \
+
+fi
+
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+  $verbose && echo "stage 3: test model"
+  cd "${work_dir}" || exit 1
+  python3 step_3_evaluation.py \
+    --valid_dataset "${valid_dataset}" \
+    --model_dir "${file_dir}/best" \
+    --evaluation_audio_dir "${evaluation_audio_dir}" \
+    --limit "${limit}" \
+
+fi
+
+
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+  $verbose && echo "stage 4: collect files"
+  cd "${work_dir}" || exit 1
+
+  mkdir -p ${final_model_dir}
+
+  cp "${file_dir}/best"/* "${final_model_dir}"
+  cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
+
+  cd "${final_model_dir}/.." || exit 1;
+
+  if [ -e "${final_model_name}.zip" ]; then
+    rm -rf "${final_model_name}_backup.zip"
+    mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
+  fi
+
+  zip -r "${final_model_name}.zip" "${final_model_name}"
+  rm -rf "${final_model_name}"
+
+fi
+
+
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+  $verbose && echo "stage 5: clear file_dir"
+  cd "${work_dir}" || exit 1
+
+  rm -rf "${file_dir}";
+
+fi
examples/dfnet/step_1_prepare_data.py ADDED
@@ -0,0 +1,164 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import argparse
+import json
+import os
+from pathlib import Path
+import random
+import sys
+
+pwd = os.path.abspath(os.path.dirname(__file__))
+sys.path.append(os.path.join(pwd, "../../"))
+
+import librosa
+import numpy as np
+from tqdm import tqdm
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--file_dir", default="./", type=str)
+
+    parser.add_argument(
+        "--noise_dir",
+        default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
+        type=str
+    )
+    parser.add_argument(
+        "--speech_dir",
+        default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
+        type=str
+    )
+
+    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
+    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
+
+    parser.add_argument("--duration", default=4.0, type=float)
+    parser.add_argument("--min_snr_db", default=-10, type=float)
+    parser.add_argument("--max_snr_db", default=20, type=float)
+
+    parser.add_argument("--target_sample_rate", default=8000, type=int)
+
+    parser.add_argument("--max_count", default=10000, type=int)
+
+    args = parser.parse_args()
+    return args
+
+
+def filename_generator(data_dir: str):
+    data_dir = Path(data_dir)
+    for filename in data_dir.glob("**/*.wav"):
+        yield filename.as_posix()
+
+
+def target_second_signal_generator(data_dir: str, duration: float = 2.0, sample_rate: int = 8000, max_epoch: int = 20000):
+    data_dir = Path(data_dir)
+    for epoch_idx in range(max_epoch):
+        for filename in data_dir.glob("**/*.wav"):
+            signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
+            raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
+
+            if raw_duration < duration:
+                # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
+                continue
+            if signal.ndim != 1:
+                raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
+
+            signal_length = len(signal)
+            win_size = int(duration * sample_rate)
+            for begin in range(0, signal_length - win_size, win_size):
+                if np.sum(signal[begin: begin+win_size]) == 0:
+                    continue
+                row = {
+                    "epoch_idx": epoch_idx,
+                    "filename": filename.as_posix(),
+                    "raw_duration": round(raw_duration, 4),
+                    "offset": round(begin / sample_rate, 4),
+                    "duration": round(duration, 4),
+                }
+                yield row
+
+
+def main():
+    args = get_args()
+
+    file_dir = Path(args.file_dir)
+    file_dir.mkdir(exist_ok=True)
+
+    noise_dir = Path(args.noise_dir)
+    speech_dir = Path(args.speech_dir)
+
+    noise_generator = target_second_signal_generator(
+        noise_dir.as_posix(),
+        duration=args.duration,
+        sample_rate=args.target_sample_rate,
+        max_epoch=100000,
+    )
+    speech_generator = target_second_signal_generator(
+        speech_dir.as_posix(),
+        duration=args.duration,
+        sample_rate=args.target_sample_rate,
+        max_epoch=1,
+    )
+
+    dataset = list()
+
+    count = 0
+    process_bar = tqdm(desc="build dataset excel")
+    with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
+        for noise, speech in zip(noise_generator, speech_generator):
+            if count >= args.max_count > 0:
+                break
+
+            noise_filename = noise["filename"]
+            noise_raw_duration = noise["raw_duration"]
+            noise_offset = noise["offset"]
+            noise_duration = noise["duration"]
+
+            speech_filename = speech["filename"]
+            speech_raw_duration = speech["raw_duration"]
+            speech_offset = speech["offset"]
+            speech_duration = speech["duration"]
+
+            random1 = random.random()
+            random2 = random.random()
+
+            row = {
+                "count": count,
+
+                "noise_filename": noise_filename,
+                "noise_raw_duration": noise_raw_duration,
+                "noise_offset": noise_offset,
+                "noise_duration": noise_duration,
+
+                "speech_filename": speech_filename,
+                "speech_raw_duration": speech_raw_duration,
+                "speech_offset": speech_offset,
+                "speech_duration": speech_duration,
+
+                "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
+
+                "random1": random1,
+            }
+            row = json.dumps(row, ensure_ascii=False)
+            if random2 < (1 / 300 / 1):
+                fvalid.write(f"{row}\n")
+            else:
+                ftrain.write(f"{row}\n")
+
+            count += 1
+            duration_seconds = count * args.duration
+            duration_hours = duration_seconds / 3600
+
+            process_bar.update(n=1)
+            process_bar.set_postfix({
+                # "duration_seconds": round(duration_seconds, 4),
+                "duration_hours": round(duration_hours, 4),
+
+            })
+
+    return
+
+
+if __name__ == "__main__":
+    main()
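
Note: target_second_signal_generator slices each wav into non-overlapping duration-second windows, skipping clips shorter than duration and windows whose samples sum to zero; main() then routes roughly 1 row in 300 to valid.jsonl and the rest to train.jsonl. A minimal standalone sketch of the window enumeration (fake signal, illustrative values):

    import numpy as np

    sample_rate = 8000
    duration = 4.0
    signal = np.random.randn(10 * sample_rate)  # a fake 10 s clip

    win_size = int(duration * sample_rate)
    offsets = [
        begin / sample_rate
        for begin in range(0, len(signal) - win_size, win_size)
        if np.sum(signal[begin: begin + win_size]) != 0
    ]
    print(offsets)  # [0.0, 4.0]; the final partial window is never emitted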
examples/dfnet/step_2_train_model.py ADDED
@@ -0,0 +1,440 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import argparse
+import json
+import logging
+from logging.handlers import TimedRotatingFileHandler
+import os
+import platform
+from pathlib import Path
+import random
+import sys
+import shutil
+from typing import List
+
+pwd = os.path.abspath(os.path.dirname(__file__))
+sys.path.append(os.path.join(pwd, "../../"))
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from torch.utils.data.dataloader import DataLoader
+from tqdm import tqdm
+
+from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
+from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
+from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
+from toolbox.torchaudio.metrics.pesq import run_pesq_score
+from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
+from toolbox.torchaudio.models.dfnet.modeling_dfnet import DfNet, DfNetPretrainedModel
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
+    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
+
+    parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
+    parser.add_argument("--patience", default=5, type=int)
+    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
+
+    parser.add_argument("--config_file", default="config.yaml", type=str)
+
+    args = parser.parse_args()
+    return args
+
+
+def logging_config(file_dir: str):
+    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
+
+    logging.basicConfig(format=fmt,
+                        datefmt="%m/%d/%Y %H:%M:%S",
+                        level=logging.INFO)
+    file_handler = TimedRotatingFileHandler(
+        filename=os.path.join(file_dir, "main.log"),
+        encoding="utf-8",
+        when="D",
+        interval=1,
+        backupCount=7
+    )
+    file_handler.setLevel(logging.INFO)
+    file_handler.setFormatter(logging.Formatter(fmt))
+    logger = logging.getLogger(__name__)
+    logger.addHandler(file_handler)
+
+    return logger
+
+
+class CollateFunction(object):
+    def __init__(self):
+        pass
+
+    def __call__(self, batch: List[dict]):
+        clean_audios = list()
+        noisy_audios = list()
+        snr_db_list = list()
+
+        for sample in batch:
+            # noise_wave: torch.Tensor = sample["noise_wave"]
+            clean_audio: torch.Tensor = sample["speech_wave"]
+            noisy_audio: torch.Tensor = sample["mix_wave"]
+            snr_db: float = sample["snr_db"]
+
+            clean_audios.append(clean_audio)
+            noisy_audios.append(noisy_audio)
+            snr_db_list.append(snr_db)
+
+        clean_audios = torch.stack(clean_audios)
+        noisy_audios = torch.stack(noisy_audios)
+        snr_db_list = torch.stack(snr_db_list)
+
+        # assert
+        if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
+            raise AssertionError("nan or inf in clean_audios")
+        if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
+            raise AssertionError("nan or inf in noisy_audios")
+        return clean_audios, noisy_audios, snr_db_list
+
+
+collate_fn = CollateFunction()
+
+
+def main():
+    args = get_args()
+
+    config = DfNetConfig.from_pretrained(
+        pretrained_model_name_or_path=args.config_file,
+    )
+
+    serialization_dir = Path(args.serialization_dir)
+    serialization_dir.mkdir(parents=True, exist_ok=True)
+
+    logger = logging_config(serialization_dir)
+
+    random.seed(config.seed)
+    np.random.seed(config.seed)
+    torch.manual_seed(config.seed)
+    logger.info(f"set seed: {config.seed}")
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    n_gpu = torch.cuda.device_count()
+    logger.info(f"GPU available count: {n_gpu}; device: {device}")
+
+    # datasets
+    train_dataset = DenoiseJsonlDataset(
+        jsonl_file=args.train_dataset,
+        expected_sample_rate=config.sample_rate,
+        max_wave_value=32768.0,
+        min_snr_db=config.min_snr_db,
+        max_snr_db=config.max_snr_db,
+        # skip=225000,
+    )
+    valid_dataset = DenoiseJsonlDataset(
+        jsonl_file=args.valid_dataset,
+        expected_sample_rate=config.sample_rate,
+        max_wave_value=32768.0,
+        min_snr_db=config.min_snr_db,
+        max_snr_db=config.max_snr_db,
+    )
+    train_data_loader = DataLoader(
+        dataset=train_dataset,
+        batch_size=config.batch_size,
+        # shuffle=True,
+        sampler=None,
+        # On Linux, multiple worker processes can be used to load data; on Windows they cannot.
+        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
+        collate_fn=collate_fn,
+        pin_memory=False,
+        prefetch_factor=2,
+    )
+    valid_data_loader = DataLoader(
+        dataset=valid_dataset,
+        batch_size=config.batch_size,
+        # shuffle=True,
+        sampler=None,
+        # On Linux, multiple worker processes can be used to load data; on Windows they cannot.
+        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
+        collate_fn=collate_fn,
+        pin_memory=False,
+        prefetch_factor=2,
+    )
+
+    # models
+    logger.info(f"prepare models. config_file: {args.config_file}")
+    model = DfNetPretrainedModel(config).to(device)
+    model.to(device)
+    model.train()
+
+    # optimizer
+    logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
+    optimizer = torch.optim.AdamW(model.parameters(), config.lr)
+
+    # resume training
+    last_step_idx = -1
+    last_epoch = -1
+    for step_idx_str in serialization_dir.glob("steps-*"):
+        step_idx_str = Path(step_idx_str)
+        step_idx = step_idx_str.stem.split("-")[1]
+        step_idx = int(step_idx)
+        if step_idx > last_step_idx:
+            last_step_idx = step_idx
+            # last_epoch = 1
+
+    if last_step_idx != -1:
+        logger.info(f"resume from steps-{last_step_idx}.")
+        model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
+        optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"
+
+        logger.info(f"load state dict for model.")
+        with open(model_pt.as_posix(), "rb") as f:
+            state_dict = torch.load(f, map_location="cpu", weights_only=True)
+        model.load_state_dict(state_dict, strict=True)
+
+        logger.info(f"load state dict for optimizer.")
+        with open(optimizer_pth.as_posix(), "rb") as f:
+            state_dict = torch.load(f, map_location="cpu", weights_only=True)
+        optimizer.load_state_dict(state_dict)
+
+    if config.lr_scheduler == "CosineAnnealingLR":
+        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+            optimizer,
+            last_epoch=last_epoch,
+            # T_max=10 * config.eval_steps,
+            # eta_min=0.01 * config.lr,
+            **config.lr_scheduler_kwargs,
+        )
+    elif config.lr_scheduler == "MultiStepLR":
+        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
+            optimizer,
+            last_epoch=last_epoch,
+            milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
+        )
+    else:
+        raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
+
+    neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
+    mr_stft_loss_fn = MultiResolutionSTFTLoss(
+        fft_size_list=[256, 512, 1024],
+        win_size_list=[256, 512, 1024],
+        hop_size_list=[128, 256, 512],
+        factor_sc=1.5,
+        factor_mag=1.0,
+        reduction="mean"
+    ).to(device)
+    lsnr_loss_fn = nn.L1Loss(reduction="mean")
+
+    # training loop
+
+    # state
+    average_pesq_score = 1000000000
+    average_loss = 1000000000
+    average_neg_si_snr_loss = 1000000000
+    average_mask_loss = 1000000000
+
+    model_list = list()
+    best_epoch_idx = None
+    best_step_idx = None
+    best_metric = None
+    patience_count = 0
+
+    step_idx = 0 if last_step_idx == -1 else last_step_idx
+
+    logger.info("training")
+    for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
+        # train
+        model.train()
+
+        total_pesq_score = 0.
+        total_loss = 0.
+        total_neg_si_snr_loss = 0.
+        total_mask_loss = 0.
+        total_batches = 0.
+
+        progress_bar_train = tqdm(
+            initial=step_idx,
+            desc="Training; epoch-{}".format(epoch_idx),
+        )
+        for train_batch in train_data_loader:
+            clean_audios, noisy_audios, snr_db_list = train_batch
+            clean_audios: torch.Tensor = clean_audios.to(device)
+            noisy_audios: torch.Tensor = noisy_audios.to(device)
+            snr_db_list: torch.Tensor = snr_db_list.to(device)
+
+            est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
+
+            neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
+            mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
+            # mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
+            # neg_si_snr_loss = lsnr_loss_fn.forward(lsnr, snr_db_list)
+
+            loss = 1.0 * neg_si_snr_loss + 1.0 * mask_loss
+            if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
+                logger.info(f"find nan or inf in loss.")
+                continue
+
+            denoise_audios_list_r = list(est_wav.detach().cpu().numpy())
+            clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
+            pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
+
+            optimizer.zero_grad()
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
+            optimizer.step()
+            lr_scheduler.step()
+
+            total_pesq_score += pesq_score
+            total_loss += loss.item()
+            total_neg_si_snr_loss += neg_si_snr_loss.item()
+            total_mask_loss += mask_loss.item()
+            total_batches += 1
+
+            average_pesq_score = round(total_pesq_score / total_batches, 4)
+            average_loss = round(total_loss / total_batches, 4)
+            average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
+            average_mask_loss = round(total_mask_loss / total_batches, 4)
+
+            progress_bar_train.update(1)
+            progress_bar_train.set_postfix({
+                "lr": lr_scheduler.get_last_lr()[0],
+                "pesq_score": average_pesq_score,
+                "loss": average_loss,
+                "neg_si_snr_loss": average_neg_si_snr_loss,
+                "mask_loss": average_mask_loss,
+            })
+
+            # evaluation
+            step_idx += 1
+            if step_idx % config.eval_steps == 0:
+                with torch.no_grad():
+                    torch.cuda.empty_cache()
+
+                    total_pesq_score = 0.
+                    total_loss = 0.
+                    total_neg_si_snr_loss = 0.
+                    total_mask_loss = 0.
+                    total_batches = 0.
+
+                    progress_bar_train.close()
+                    progress_bar_eval = tqdm(
+                        desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
+                    )
+                    for eval_batch in valid_data_loader:
+                        clean_audios, noisy_audios, snr_db_list = eval_batch
+                        clean_audios: torch.Tensor = clean_audios.to(device)
+                        noisy_audios: torch.Tensor = noisy_audios.to(device)
+                        snr_db_list: torch.Tensor = snr_db_list.to(device)
+
+                        est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
+
+                        neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
+                        mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
+
+                        loss = 1.0 * neg_si_snr_loss + 1.0 * mask_loss
+                        if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
+                            logger.info(f"find nan or inf in loss.")
+                            continue
+
+                        denoise_audios_list_r = list(est_wav.detach().cpu().numpy())
+                        clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
+                        pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
+
+                        total_pesq_score += pesq_score
+                        total_loss += loss.item()
+                        total_neg_si_snr_loss += neg_si_snr_loss.item()
+                        total_mask_loss += mask_loss.item()
+                        total_batches += 1
+
+                        average_pesq_score = round(total_pesq_score / total_batches, 4)
+                        average_loss = round(total_loss / total_batches, 4)
+                        average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
+                        average_mask_loss = round(total_mask_loss / total_batches, 4)
+
+                        progress_bar_eval.update(1)
+                        progress_bar_eval.set_postfix({
+                            "lr": lr_scheduler.get_last_lr()[0],
+                            "pesq_score": average_pesq_score,
+                            "loss": average_loss,
+                            "neg_si_snr_loss": average_neg_si_snr_loss,
+                            "mask_loss": average_mask_loss,
+                        })
+
+                    total_pesq_score = 0.
+                    total_loss = 0.
+                    total_neg_si_snr_loss = 0.
+                    total_mask_loss = 0.
+                    total_batches = 0.
+
+                    progress_bar_eval.close()
+                    progress_bar_train = tqdm(
+                        initial=progress_bar_train.n,
+                        postfix=progress_bar_train.postfix,
+                        desc=progress_bar_train.desc,
+                    )
+
+                    # save path
+                    save_dir = serialization_dir / "steps-{}".format(step_idx)
+                    save_dir.mkdir(parents=True, exist_ok=False)
+
+                    # save models
+                    model.save_pretrained(save_dir.as_posix())
+
+                    model_list.append(save_dir)
+                    if len(model_list) >= args.num_serialized_models_to_keep:
+                        model_to_delete: Path = model_list.pop(0)
+                        shutil.rmtree(model_to_delete.as_posix())
+
+                    # save optim
+                    torch.save(optimizer.state_dict(), (save_dir / "optimizer.pth").as_posix())
+
+                    # save metric
+                    if best_metric is None:
+                        best_epoch_idx = epoch_idx
+                        best_step_idx = step_idx
+                        best_metric = average_pesq_score
+                    elif average_pesq_score > best_metric:
+                        # greater is better.
+                        best_epoch_idx = epoch_idx
+                        best_step_idx = step_idx
+                        best_metric = average_pesq_score
+                    else:
+                        pass
+
+                    metrics = {
+                        "epoch_idx": epoch_idx,
+                        "best_epoch_idx": best_epoch_idx,
+                        "best_step_idx": best_step_idx,
+                        "pesq_score": average_pesq_score,
+                        "loss": average_loss,
+                        "neg_si_snr_loss": average_neg_si_snr_loss,
+                        "mask_loss": average_mask_loss,
+                    }
+                    metrics_filename = save_dir / "metrics_epoch.json"
+                    with open(metrics_filename, "w", encoding="utf-8") as f:
+                        json.dump(metrics, f, indent=4, ensure_ascii=False)
+
+                    # save best
+                    best_dir = serialization_dir / "best"
+                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
+                        if best_dir.exists():
+                            shutil.rmtree(best_dir)
+                        shutil.copytree(save_dir, best_dir)
+
+                    # early stop
+                    early_stop_flag = False
+                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
+                        patience_count = 0
+                    else:
+                        patience_count += 1
+                        if patience_count >= args.patience:
+                            early_stop_flag = True
+
+                    # early stop
+                    if early_stop_flag:
+                        break
+
+    return
+
+
+if __name__ == "__main__":
+    main()
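
Note: resuming works by scanning serialization_dir for steps-&lt;N&gt; checkpoint directories and reloading the largest N; early stopping then counts evaluations since the last best PESQ and aborts after args.patience misses. A minimal sketch of the checkpoint-index scan, assuming the same steps-{N} naming used above:

    from pathlib import Path

    def latest_step(serialization_dir: Path) -> int:
        # -1 means no checkpoint exists yet, matching last_step_idx above
        last_step_idx = -1
        for path in serialization_dir.glob("steps-*"):
            step_idx = int(path.stem.split("-")[1])
            last_step_idx = max(last_step_idx, step_idx)
        return last_step_idx

    print(latest_step(Path("serialization_dir")))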
examples/dfnet/yaml/config.yaml ADDED
@@ -0,0 +1,53 @@
+model_name: "dfnet"
+
+# spec
+sample_rate: 8000
+n_fft: 512
+win_length: 200
+hop_length: 80
+
+spec_bins: 256
+
+# model
+conv_channels: 64
+conv_kernel_size_input:
+  - 3
+  - 3
+conv_kernel_size_inner:
+  - 1
+  - 3
+conv_lookahead: 0
+
+convt_kernel_size_inner:
+  - 1
+  - 3
+
+embedding_hidden_size: 256
+encoder_combine_op: "concat"
+
+encoder_emb_skip_op: "none"
+encoder_emb_linear_groups: 16
+encoder_emb_hidden_size: 256
+
+encoder_linear_groups: 32
+
+lsnr_max: 30
+lsnr_min: -15
+norm_tau: 1.
+
+decoder_emb_num_layers: 3
+decoder_emb_skip_op: "none"
+decoder_emb_linear_groups: 16
+decoder_emb_hidden_size: 256
+
+df_decoder_hidden_size: 256
+df_num_layers: 2
+df_order: 5
+df_bins: 96
+df_gru_skip: "grouped_linear"
+df_decoder_linear_groups: 16
+df_pathway_kernel_size_t: 5
+df_lookahead: 2
+
+# runtime
+use_post_filter: true
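
Note: step_2_train_model.py consumes this file via DfNetConfig.from_pretrained(pretrained_model_name_or_path=args.config_file), so the keys above must correspond to DfNetConfig fields; the training-side values it also reads (seed, lr, lr_scheduler, lr_scheduler_kwargs, batch_size, max_epochs, eval_steps, clip_grad_norm, min_snr_db, max_snr_db) presumably come from the same config class, and modeling_dfnet.py refers to config.nfft/win_size/hop_size where this file spells n_fft/win_length/hop_length, so DfNetConfig is assumed to map between the two spellings. A minimal sketch of loading it, assuming the repo root is on PYTHONPATH:

    from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig

    config = DfNetConfig.from_pretrained(pretrained_model_name_or_path="yaml/config.yaml")
    print(config.sample_rate)  # 8000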
examples/frcrn/step_1_prepare_data.py CHANGED
@@ -39,7 +39,7 @@ def get_args():
 
     parser.add_argument("--target_sample_rate", default=8000, type=int)
 
-    parser.add_argument("--max_count", default=10000, type=int)
+    parser.add_argument("--scale", default=1, type=float)
 
     args = parser.parse_args()
     return args
@@ -107,8 +107,9 @@ def main():
     process_bar = tqdm(desc="build dataset excel")
     with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
         for noise, speech in zip(noise_generator, speech_generator):
-            if count >= args.max_count:
-                break
+            flag = random.random()
+            if flag > args.scale:
+                continue
 
             noise_filename = noise["filename"]
             noise_raw_duration = noise["raw_duration"]
@@ -124,6 +125,8 @@ def main():
             random2 = random.random()
 
             row = {
+                "count": count,
+
                 "noise_filename": noise_filename,
                 "noise_raw_duration": noise_raw_duration,
                 "noise_offset": noise_offset,
examples/mpnet/step_1_prepare_data.py CHANGED
@@ -119,6 +119,8 @@ def get_dataset(args):
             random2 = random.random()
 
             row = {
+                "count": count,
+
                 "noise_filename": noise_filename,
                 "noise_raw_duration": noise_raw_duration,
                 "noise_offset": noise_offset,
toolbox/torch/utils/data/dataset/denoise_jsonl_dataset.py CHANGED
@@ -35,6 +35,8 @@ class DenoiseJsonlDataset(IterableDataset):
         self.buffer_samples: List[dict] = list()
 
     def __iter__(self):
+        self.buffer_samples = list()
+
         iterable_source = self.iterable_source()
 
         try:
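
Note: re-initializing buffer_samples at the top of __iter__ makes each pass over the IterableDataset start from an empty shuffle buffer, so a second epoch's DataLoader iteration no longer inherits samples left over from the previous pass. A minimal standalone sketch of the pattern (simplified buffer, illustrative sizes):

    from typing import Iterator, List

    class BufferedIterable:
        def __init__(self, source: List[int], buffer_size: int = 4):
            self.source = source
            self.buffer_size = buffer_size
            self.buffer_samples: List[int] = list()

        def __iter__(self) -> Iterator[int]:
            self.buffer_samples = list()  # fresh buffer on every pass
            for sample in self.source:
                self.buffer_samples.append(sample)
                if len(self.buffer_samples) >= self.buffer_size:
                    yield self.buffer_samples.pop(0)
            while self.buffer_samples:
                yield self.buffer_samples.pop(0)

    data = BufferedIterable(list(range(6)))
    assert list(data) == list(data)  # repeat iteration yields identical output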
toolbox/torchaudio/models/dfnet/modeling_dfnet.py CHANGED
@@ -7,6 +7,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 import numpy as np
 import torch
 import torch.nn as nn
+from torch.nn import functional as F
 import torchaudio
 
 from toolbox.torchaudio.configuration_utils import CONFIG_FILE
@@ -752,11 +753,11 @@ class DeepFiltering(nn.Module):
                 coefs: torch.Tensor,
                 ):
         # spec shape: [batch_size, 1, time_steps, spec_bins, 2]
-        spec_u = self.spec_unfold(torch.view_as_complex(spec))
+        spec_u = self.spec_unfold(torch.view_as_complex(spec.contiguous()))
         # spec_u shape: [batch_size, 1, time_steps, spec_bins, df_order]
 
         # coefs shape: [batch_size, df_order, time_steps, df_bins, 2]
-        coefs = torch.view_as_complex(coefs)
+        coefs = torch.view_as_complex(coefs.contiguous())
         # coefs shape: [batch_size, df_order, time_steps, df_bins]
         spec_f = spec_u.narrow(-2, 0, self.df_bins)
         # spec_f shape: [batch_size, 1, time_steps, df_bins, df_order]
@@ -789,6 +790,13 @@ class DfNet(nn.Module):
         super(DfNet, self).__init__()
         self.config = config
 
+        self.freq_bins = self.config.nfft // 2 + 1
+
+        self.nfft = config.nfft
+        self.win_size = config.win_size
+        self.hop_size = config.hop_size
+        self.win_type = config.win_type
+
         self.stft = ConvSTFT(
             nfft=config.nfft,
             win_size=config.win_size,
@@ -820,32 +828,41 @@ class DfNet(nn.Module):
         self.mask = Mask(use_post_filter=config.use_post_filter)
 
     def forward(self,
-                spec_complex: torch.Tensor,
+                noisy: torch.Tensor,
                 ):
-        feat_power = torch.square(torch.abs(spec_complex))
-        feat_power = feat_power.unsqueeze(1).permute(0, 1, 3, 2)
-        # feat_power shape: [batch_size, spec_bins, time_steps]
-        # feat_power shape: [batch_size, 1, spec_bins, time_steps]
+        if noisy.dim() == 2:
+            noisy = torch.unsqueeze(noisy, dim=1)
+        _, _, n_samples = noisy.shape
+        remainder = (n_samples - self.win_size) % self.hop_size
+        if remainder > 0:
+            n_samples_pad = self.hop_size - remainder
+            noisy = F.pad(noisy, pad=(0, n_samples_pad), mode="constant", value=0)
+
+        # [batch_size, freq_bins * 2, time_steps]
+        cmp_spec = self.stft.forward(noisy)
+        # [batch_size, 1, freq_bins * 2, time_steps]
+        cmp_spec = torch.unsqueeze(cmp_spec, 1)
+
+        # [batch_size, 2, freq_bins, time_steps]
+        cmp_spec = torch.cat([
+            cmp_spec[:, :, :self.freq_bins, :],
+            cmp_spec[:, :, self.freq_bins:, :],
+        ], dim=1)
+        # n//2+1 -> n//2; 257 -> 256
+        cmp_spec = cmp_spec[:, :, :-1, :]
+
+        spec = torch.unsqueeze(cmp_spec, dim=4)
+        # [batch_size, 2, freq_bins, time_steps, 1]
+        spec = spec.permute(0, 4, 3, 2, 1)
+        # spec shape: [batch_size, 1, time_steps, freq_bins, 2]
+
+        feat_power = torch.sum(torch.square(spec), dim=-1)
         # feat_power shape: [batch_size, 1, time_steps, spec_bins]
-        feat_power = feat_power.detach()
 
-        # spec shape: [batch_size, spec_bins, time_steps]
-        feat_spec = torch.view_as_real(spec_complex)
-        # spec shape: [batch_size, spec_bins, time_steps, 2]
-        feat_spec = feat_spec.permute(0, 3, 2, 1)
-        # feat_spec shape: [batch_size, 2, time_steps, spec_bins]
+        feat_spec = torch.transpose(cmp_spec, dim0=2, dim1=3)
+        # feat_spec shape: [batch_size, 2, time_steps, freq_bins]
         feat_spec = feat_spec[..., :self.df_decoder.df_bins]
         # feat_spec shape: [batch_size, 2, time_steps, df_bins]
-        feat_spec = feat_spec.detach()
-
-        # spec shape: [batch_size, spec_bins, time_steps]
-        spec = torch.unsqueeze(spec_complex, dim=1)
-        # spec shape: [batch_size, 1, spec_bins, time_steps]
-        spec = spec.permute(0, 1, 3, 2)
-        # spec shape: [batch_size, 1, time_steps, spec_bins]
-        spec = torch.view_as_real(spec)
-        # spec shape: [batch_size, 1, time_steps, spec_bins, 2]
-        spec = spec.detach()
 
         e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_power, feat_spec)
 
@@ -865,7 +882,7 @@ class DfNet(nn.Module):
         # df_coefs shape: [batch_size, df_order, time_steps, df_bins, 2]
 
         spec_e = self.df_op.forward(spec.clone(), df_coefs)
-        # spec_e shape: [batch_size, 1, time_steps, spec_bins, 2]
+        # est_spec shape: [batch_size, 1, time_steps, spec_bins, 2]
 
         spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]
 
@@ -874,10 +891,68 @@ class DfNet(nn.Module):
         # spec_e shape: [batch_size, spec_bins, time_steps, 2]
 
         mask = torch.squeeze(mask, dim=1)
-        mask = mask.permute(0, 2, 1)
+        est_mask = mask.permute(0, 2, 1)
         # mask shape: [batch_size, spec_bins, time_steps]
 
-        return spec_e, mask, lsnr
+        b, _, t, _ = spec_e.shape
+        est_spec = torch.cat(tensors=[
+            torch.concat(tensors=[
+                spec_e[..., 0],
+                torch.zeros(size=(b, 1, t), dtype=spec_e.dtype).to(spec_e.device)
+            ], dim=1),
+            torch.concat(tensors=[
+                spec_e[..., 1],
+                torch.zeros(size=(b, 1, t), dtype=spec_e.dtype).to(spec_e.device)
+            ], dim=1),
+        ], dim=1)
+        # est_spec shape: [b, n+2, t]
+        est_wav = self.istft.forward(est_spec)
+        est_wav = torch.squeeze(est_wav, dim=1)
+        est_wav = est_wav[:, :n_samples]
+        # est_wav shape: [b, n_samples]
+        return est_spec, est_wav, est_mask, lsnr
+
+    def mask_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
+        """
+
+        :param est_mask: torch.Tensor, shape: [b, n+2, t]
+        :param clean:
+        :param noisy:
+        :return:
+        """
+        clean_stft = self.stft(clean)
+        clean_re = clean_stft[:, :self.freq_bins, :]
+        clean_im = clean_stft[:, self.freq_bins:, :]
+
+        noisy_stft = self.stft(noisy)
+        noisy_re = noisy_stft[:, :self.freq_bins, :]
+        noisy_im = noisy_stft[:, self.freq_bins:, :]
+
+        noisy_power = noisy_re ** 2 + noisy_im ** 2
+
+        sr = clean_re
+        yr = noisy_re
+        si = clean_im
+        yi = noisy_im
+        y_pow = noisy_power
+        # (Sr * Yr + Si * Yi) / (Y_pow + 1e-8)
+        gth_mask_re = (sr * yr + si * yi) / (y_pow + self.eps)
+        # (Si * Yr - Sr * Yi) / (Y_pow + 1e-8)
+        gth_mask_im = (si * yr - sr * yi) / (y_pow + self.eps)
+
+        gth_mask_re[gth_mask_re > 2] = 1
+        gth_mask_re[gth_mask_re < -2] = -1
+        gth_mask_im[gth_mask_im > 2] = 1
+        gth_mask_im[gth_mask_im < -2] = -1
+
+        mask_re = est_mask[:, :self.freq_bins, :]
+        mask_im = est_mask[:, self.freq_bins:, :]
+
+        loss_re = F.mse_loss(gth_mask_re, mask_re)
+        loss_im = F.mse_loss(gth_mask_im, mask_im)
+
+        loss = loss_re + loss_im
+        return loss
 
 
 class DfNetPretrainedModel(DfNet):
@@ -928,22 +1003,12 @@
 
 def main():
 
-    transformer = torchaudio.transforms.Spectrogram(
-        n_fft=512,
-        win_length=200,
-        hop_length=80,
-        window_fn=torch.hamming_window,
-        power=None,
-    )
-
     config = DfNetConfig()
     model = DfNetPretrainedModel(config=config)
 
-    inputs = torch.randn(size=(1, 16000), dtype=torch.float32)
-    spec_complex = transformer.forward(inputs)
-    spec_complex = spec_complex[:, :-1, :]
+    noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
 
-    output = model.forward(spec_complex)
+    output = model.forward(noisy)
     print(output[1].shape)
     return
 
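
Note: the ground-truth target in mask_loss_fn is a clipped complex ideal ratio mask. With clean spectrum S = S_r + i S_i and noisy spectrum Y = Y_r + i Y_i,

    M = \frac{S \, \overline{Y}}{|Y|^2 + \epsilon}, \qquad
    M_r = \frac{S_r Y_r + S_i Y_i}{Y_r^2 + Y_i^2 + \epsilon}, \qquad
    M_i = \frac{S_i Y_r - S_r Y_i}{Y_r^2 + Y_i^2 + \epsilon}

which is exactly what the inline comments above spell out; components above 2 are set to 1 and below -2 to -1 before the MSE against the estimated mask.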