HoneyTian committed
Commit cba47e4 · 1 Parent(s): 1b032b9

add frcrn model

examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -300,7 +300,7 @@ def main():
 # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
 # loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
 # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
-loss = 0.2 * ae_loss + 0.2 * neg_si_snr_loss + 1.0 * mr_stft_loss + 0.3 * neg_stoi_loss + 0.5 * pesq_loss
+loss = 0.1 * ae_loss + 0.1 * neg_si_snr_loss + 1.0 * mr_stft_loss + 0.2 * neg_stoi_loss + 0.2 * pesq_loss
 if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
     logger.info(f"find nan or inf in loss.")
     continue
@@ -381,7 +381,8 @@ def main():
 # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.5 * mr_stft_loss + 0.3 * neg_stoi_loss
 # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
 # loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
-loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
+# loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
+loss = 0.1 * ae_loss + 0.1 * neg_si_snr_loss + 1.0 * mr_stft_loss + 0.2 * neg_stoi_loss + 0.2 * pesq_loss
 if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
     logger.info(f"find nan or inf in loss.")
     continue
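
The hunks above only re-balance the weights of the composite training loss. A minimal sketch of the same idea with the weights pulled into one dict, so that experiments like this commit are a one-line edit (the loss values below are placeholders, not the repo's real loss terms):

    import torch

    # hypothetical stand-ins for the individual loss terms computed in the training loop
    losses = {
        "ae_loss": torch.tensor(0.7),
        "neg_si_snr_loss": torch.tensor(-5.0),
        "mr_stft_loss": torch.tensor(1.2),
        "neg_stoi_loss": torch.tensor(-0.8),
        "pesq_loss": torch.tensor(2.1),
    }

    # the weights introduced by this commit; only this dict needs editing to re-balance
    weights = {
        "ae_loss": 0.1,
        "neg_si_snr_loss": 0.1,
        "mr_stft_loss": 1.0,
        "neg_stoi_loss": 0.2,
        "pesq_loss": 0.2,
    }

    loss = sum(w * losses[name] for name, w in weights.items())
    print(loss)
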
examples/frcrn/run.sh ADDED
@@ -0,0 +1,154 @@
#!/usr/bin/env bash

: <<'END'


sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name conv-tasnet-dns3-20250319 \
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" \
--max_epochs 400


END


# params
system_version="windows";
verbose=true;
stage=0 # start from 0 if you need to start from data preparation
stop_stage=9

work_dir="$(pwd)"
file_folder_name=file_folder_name
final_model_name=final_model_name
config_file="yaml/config.yaml"
limit=10

noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train

max_count=10000000

nohup_name=nohup.out

# model params
batch_size=64
max_epochs=200
save_top_k=10
patience=5


# parse options
while true; do
  [ -z "${1:-}" ] && break;  # break if there are no arguments
  case "$1" in
    --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
      eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
      old_value="$(eval echo \$$name)";  # command substitution so the current value is captured
      if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
        was_bool=true;
      else
        was_bool=false;
      fi

      # Set the variable to the right value-- the escaped quotes make it work if
      # the option had spaces, like --cmd "queue.pl -sync y"
      eval "${name}=\"$2\"";

      # Check that Boolean-valued arguments are really Boolean.
      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
        exit 1;
      fi
      shift 2;
      ;;

    *) break;
  esac
done

file_dir="${work_dir}/${file_folder_name}"
final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
evaluation_audio_dir="${file_dir}/evaluation_audio"

train_dataset="${file_dir}/train.jsonl"
valid_dataset="${file_dir}/valid.jsonl"

$verbose && echo "system_version: ${system_version}"
$verbose && echo "file_folder_name: ${file_folder_name}"

if [ $system_version == "windows" ]; then
  alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
  #source /data/local/bin/nx_denoise/bin/activate
  alias python3='/data/local/bin/nx_denoise/bin/python3'
fi


if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  $verbose && echo "stage 1: prepare data"
  cd "${work_dir}" || exit 1
  python3 step_1_prepare_data.py \
    --file_dir "${file_dir}" \
    --noise_dir "${noise_dir}" \
    --speech_dir "${speech_dir}" \
    --train_dataset "${train_dataset}" \
    --valid_dataset "${valid_dataset}" \
    --max_count "${max_count}" \

fi


if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  $verbose && echo "stage 2: train model"
  cd "${work_dir}" || exit 1
  python3 step_2_train_model.py \
    --train_dataset "${train_dataset}" \
    --valid_dataset "${valid_dataset}" \
    --serialization_dir "${file_dir}" \
    --config_file "${config_file}" \

fi


if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  $verbose && echo "stage 3: test model"
  cd "${work_dir}" || exit 1
  python3 step_3_evaluation.py \
    --valid_dataset "${valid_dataset}" \
    --model_dir "${file_dir}/best" \
    --evaluation_audio_dir "${evaluation_audio_dir}" \
    --limit "${limit}" \

fi


if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  $verbose && echo "stage 4: collect files"
  cd "${work_dir}" || exit 1

  mkdir -p ${final_model_dir}

  cp "${file_dir}/best"/* "${final_model_dir}"
  cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"

  cd "${final_model_dir}/.." || exit 1;

  if [ -e "${final_model_name}.zip" ]; then
    rm -rf "${final_model_name}_backup.zip"
    mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
  fi

  zip -r "${final_model_name}.zip" "${final_model_name}"
  rm -rf "${final_model_name}"

fi


if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  $verbose && echo "stage 5: clear file_dir"
  cd "${work_dir}" || exit 1

  rm -rf "${file_dir}";

fi
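
The option loop in run.sh follows the Kaldi parse_options.sh convention: every --flag must match a variable already declared above, dashes map to underscores, and boolean variables only accept "true"/"false". A rough Python rendering of that convention, purely to document the behaviour (the argv values here are illustrative):

    import sys

    defaults = {"stage": 0, "stop_stage": 9, "verbose": True, "max_epochs": 200}

    def parse_options(argv, params):
        it = iter(argv)
        for flag in it:
            if not flag.startswith("--"):
                break
            name = flag[2:].replace("-", "_")
            if name not in params:
                raise SystemExit(f"invalid option {flag}")
            value = next(it)
            if isinstance(params[name], bool):
                # boolean variables must stay boolean
                if value not in ("true", "false"):
                    raise SystemExit(f'expected "true" or "false": {flag} {value}')
                params[name] = (value == "true")
            else:
                params[name] = type(params[name])(value)
        return params

    print(parse_options(["--stage", "2", "--stop-stage", "2", "--max-epochs", "400"], dict(defaults)))
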
examples/frcrn/step_1_prepare_data.py ADDED
@@ -0,0 +1,162 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
import os
from pathlib import Path
import random
import sys

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import librosa
import numpy as np
from tqdm import tqdm


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--file_dir", default="./", type=str)

    parser.add_argument(
        "--noise_dir",
        default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
        type=str
    )
    parser.add_argument(
        "--speech_dir",
        default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
        type=str
    )

    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)

    parser.add_argument("--duration", default=4.0, type=float)
    parser.add_argument("--min_snr_db", default=-10, type=float)
    parser.add_argument("--max_snr_db", default=20, type=float)

    parser.add_argument("--target_sample_rate", default=8000, type=int)

    parser.add_argument("--max_count", default=10000, type=int)

    args = parser.parse_args()
    return args


def filename_generator(data_dir: str):
    data_dir = Path(data_dir)
    for filename in data_dir.glob("**/*.wav"):
        yield filename.as_posix()


def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000, max_epoch: int = 20000):
    data_dir = Path(data_dir)
    for epoch_idx in range(max_epoch):
        for filename in data_dir.glob("**/*.wav"):
            signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
            raw_duration = librosa.get_duration(y=signal, sr=sample_rate)

            if raw_duration < duration:
                # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
                continue
            if signal.ndim != 1:
                raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")

            signal_length = len(signal)
            win_size = int(duration * sample_rate)
            for begin in range(0, signal_length - win_size, win_size):
                if np.sum(signal[begin: begin+win_size]) == 0:
                    continue
                row = {
                    "epoch_idx": epoch_idx,
                    "filename": filename.as_posix(),
                    "raw_duration": round(raw_duration, 4),
                    "offset": round(begin / sample_rate, 4),
                    "duration": round(duration, 4),
                }
                yield row


def main():
    args = get_args()

    file_dir = Path(args.file_dir)
    file_dir.mkdir(exist_ok=True)

    noise_dir = Path(args.noise_dir)
    speech_dir = Path(args.speech_dir)

    noise_generator = target_second_signal_generator(
        noise_dir.as_posix(),
        duration=args.duration,
        sample_rate=args.target_sample_rate,
        max_epoch=100000,
    )
    speech_generator = target_second_signal_generator(
        speech_dir.as_posix(),
        duration=args.duration,
        sample_rate=args.target_sample_rate,
        max_epoch=1,
    )

    dataset = list()

    count = 0
    process_bar = tqdm(desc="build dataset jsonl")
    with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
        for noise, speech in zip(noise_generator, speech_generator):
            if count >= args.max_count:
                break

            noise_filename = noise["filename"]
            noise_raw_duration = noise["raw_duration"]
            noise_offset = noise["offset"]
            noise_duration = noise["duration"]

            speech_filename = speech["filename"]
            speech_raw_duration = speech["raw_duration"]
            speech_offset = speech["offset"]
            speech_duration = speech["duration"]

            random1 = random.random()
            random2 = random.random()

            row = {
                "noise_filename": noise_filename,
                "noise_raw_duration": noise_raw_duration,
                "noise_offset": noise_offset,
                "noise_duration": noise_duration,

                "speech_filename": speech_filename,
                "speech_raw_duration": speech_raw_duration,
                "speech_offset": speech_offset,
                "speech_duration": speech_duration,

                "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),

                "random1": random1,
            }
            row = json.dumps(row, ensure_ascii=False)
            if random2 < (1 / 300 / 1):
                fvalid.write(f"{row}\n")
            else:
                ftrain.write(f"{row}\n")

            count += 1
            duration_seconds = count * args.duration
            duration_hours = duration_seconds / 3600

            process_bar.update(n=1)
            process_bar.set_postfix({
                # "duration_seconds": round(duration_seconds, 4),
                "duration_hours": round(duration_hours, 4),
            })

    return


if __name__ == "__main__":
    main()
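
step_1_prepare_data.py writes one JSON object per line. A small sketch of how a resulting train.jsonl row can be consumed downstream (the file paths and values below are made up for illustration):

    import json

    # one line as produced by main() above (illustrative values)
    line = (
        '{"noise_filename": "noise/a.wav", "noise_raw_duration": 12.5, "noise_offset": 4.0, '
        '"noise_duration": 4.0, "speech_filename": "speech/b.wav", "speech_raw_duration": 6.2, '
        '"speech_offset": 0.0, "speech_duration": 4.0, "snr_db": 7.3, "random1": 0.42}'
    )

    row = json.loads(line)
    # each row points at a 4 s window in a noise file and a 4 s window in a speech file,
    # plus the SNR at which the two should be mixed
    print(row["speech_filename"], row["noise_filename"], row["snr_db"])
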
examples/frcrn/step_2_train_model.py ADDED
@@ -0,0 +1,436 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
import platform
from pathlib import Path
import random
import sys
import shutil
from typing import List

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
from toolbox.torchaudio.metrics.pesq import run_pesq_score
from toolbox.torchaudio.models.frcrn.configuration_frcrn import FRCRNConfig
from toolbox.torchaudio.models.frcrn.modeling_frcrn import FRCRN, FRCRNPretrainedModel


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)

    parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
    parser.add_argument("--patience", default=5, type=int)
    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)

    parser.add_argument("--config_file", default="config.yaml", type=str)

    args = parser.parse_args()
    return args


def logging_config(file_dir: str):
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(format=fmt,
                        datefmt="%m/%d/%Y %H:%M:%S",
                        level=logging.INFO)
    file_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(fmt))
    logger = logging.getLogger(__name__)
    logger.addHandler(file_handler)

    return logger


class CollateFunction(object):
    def __init__(self):
        pass

    def __call__(self, batch: List[dict]):
        clean_audios = list()
        noisy_audios = list()

        for sample in batch:
            # noise_wave: torch.Tensor = sample["noise_wave"]
            clean_audio: torch.Tensor = sample["speech_wave"]
            noisy_audio: torch.Tensor = sample["mix_wave"]
            # snr_db: float = sample["snr_db"]

            clean_audios.append(clean_audio)
            noisy_audios.append(noisy_audio)

        clean_audios = torch.stack(clean_audios)
        noisy_audios = torch.stack(noisy_audios)

        # assert
        if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
            raise AssertionError("nan or inf in clean_audios")
        if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
            raise AssertionError("nan or inf in noisy_audios")
        return clean_audios, noisy_audios


collate_fn = CollateFunction()


def main():
    args = get_args()

    config = FRCRNConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
    )

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    logger.info(f"set seed: {config.seed}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info(f"GPU available count: {n_gpu}; device: {device}")

    # datasets
    train_dataset = DenoiseJsonlDataset(
        jsonl_file=args.train_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
        # skip=225000,
    )
    valid_dataset = DenoiseJsonlDataset(
        jsonl_file=args.valid_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
    )
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,  # batch_size comes from the config; there is no --batch_size CLI argument
        # shuffle=True,
        sampler=None,
        # On Linux multiple subprocesses can be used to load data; on Windows they cannot.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        prefetch_factor=2,
    )
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=config.batch_size,
        # shuffle=True,
        sampler=None,
        # On Linux multiple subprocesses can be used to load data; on Windows they cannot.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        prefetch_factor=2,
    )

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    model = FRCRNPretrainedModel(config).to(device)
    model.to(device)
    model.train()

    # optimizer
    logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
    optimizer = torch.optim.AdamW(model.get_params(weight_decay=config.weight_decay), config.lr)

    # resume training
    last_step_idx = -1
    last_epoch = -1
    for step_idx_str in serialization_dir.glob("steps-*"):
        step_idx_str = Path(step_idx_str)
        step_idx = step_idx_str.stem.split("-")[1]
        step_idx = int(step_idx)
        if step_idx > last_step_idx:
            last_step_idx = step_idx
            last_epoch = 1

    if last_step_idx != -1:
        logger.info(f"resume from steps-{last_step_idx}.")
        model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
        optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"

        logger.info(f"load state dict for model.")
        with open(model_pt.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)

        logger.info(f"load state dict for optimizer.")
        with open(optimizer_pth.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        optimizer.load_state_dict(state_dict)

    if config.lr_scheduler == "CosineAnnealingLR":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            last_epoch=last_epoch,
            # T_max=10 * config.eval_steps,
            # eta_min=0.01 * config.lr,
            **config.lr_scheduler_kwargs,
        )
    elif config.lr_scheduler == "MultiStepLR":
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            last_epoch=last_epoch,
            milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
        )
    else:
        raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")

    neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)

    # training loop

    # state
    average_pesq_score = 1000000000
    average_loss = 1000000000
    average_neg_si_snr_loss = 1000000000
    average_mag_loss = 1000000000
    average_pha_loss = 1000000000

    model_list = list()
    best_epoch_idx = None
    best_step_idx = None
    best_metric = None
    patience_count = 0

    step_idx = 0 if last_step_idx == -1 else last_step_idx

    logger.info("training")
    for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
        # train
        model.train()

        total_pesq_score = 0.
        total_loss = 0.
        total_neg_si_snr_loss = 0.
        total_mag_loss = 0.
        total_pha_loss = 0.
        total_batches = 0.

        progress_bar_train = tqdm(
            initial=step_idx,
            desc="Training; epoch-{}".format(epoch_idx),
        )
        for train_batch in train_data_loader:
            clean_audios, noisy_audios = train_batch
            clean_audios: torch.Tensor = clean_audios.to(device)
            noisy_audios: torch.Tensor = noisy_audios.to(device)

            est_spec, est_wav, est_mask = model.forward(noisy_audios)
            denoise_audios = est_wav

            neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
            mag_loss, pha_loss = model.mag_pha_loss_fn(est_mask, clean_audios, noisy_audios)

            loss = 0.5 * mag_loss + 0.5 * pha_loss + 0.5 * neg_si_snr_loss
            if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                logger.info(f"find nan or inf in loss.")
                continue

            denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
            clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
            pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
            optimizer.step()
            lr_scheduler.step()

            total_pesq_score += pesq_score
            total_loss += loss.item()
            total_neg_si_snr_loss += neg_si_snr_loss.item()
            total_mag_loss += mag_loss.item()
            total_pha_loss += pha_loss.item()
            total_batches += 1

            average_pesq_score = round(total_pesq_score / total_batches, 4)
            average_loss = round(total_loss / total_batches, 4)
            average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
            average_mag_loss = round(total_mag_loss / total_batches, 4)
            average_pha_loss = round(total_pha_loss / total_batches, 4)

            progress_bar_train.update(1)
            progress_bar_train.set_postfix({
                "lr": lr_scheduler.get_last_lr()[0],
                "pesq_score": average_pesq_score,
                "loss": average_loss,
                "neg_si_snr_loss": average_neg_si_snr_loss,
                "mag_loss": average_mag_loss,
                "pha_loss": average_pha_loss,
            })

            # evaluation
            step_idx += 1
            if step_idx % config.eval_steps == 0:
                with torch.no_grad():
                    torch.cuda.empty_cache()

                    total_pesq_score = 0.
                    total_loss = 0.
                    total_neg_si_snr_loss = 0.
                    total_mag_loss = 0.
                    total_pha_loss = 0.
                    total_batches = 0.

                    progress_bar_train.close()
                    progress_bar_eval = tqdm(
                        desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
                    )
                    for eval_batch in valid_data_loader:
                        clean_audios, noisy_audios = eval_batch
                        clean_audios = clean_audios.to(device)
                        noisy_audios = noisy_audios.to(device)

                        est_spec, est_wav, est_mask = model.forward(noisy_audios)
                        denoise_audios = est_wav

                        neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
                        mag_loss, pha_loss = model.mag_pha_loss_fn(est_mask, clean_audios, noisy_audios)

                        loss = 0.5 * mag_loss + 0.5 * pha_loss + 0.5 * neg_si_snr_loss
                        if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                            logger.info(f"find nan or inf in loss.")
                            continue

                        denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
                        clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
                        pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")

                        total_pesq_score += pesq_score
                        total_loss += loss.item()
                        total_neg_si_snr_loss += neg_si_snr_loss.item()
                        total_mag_loss += mag_loss.item()
                        total_pha_loss += pha_loss.item()
                        total_batches += 1

                        average_pesq_score = round(total_pesq_score / total_batches, 4)
                        average_loss = round(total_loss / total_batches, 4)
                        average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
                        average_mag_loss = round(total_mag_loss / total_batches, 4)
                        average_pha_loss = round(total_pha_loss / total_batches, 4)

                        progress_bar_eval.update(1)
                        progress_bar_eval.set_postfix({
                            "lr": lr_scheduler.get_last_lr()[0],
                            "pesq_score": average_pesq_score,
                            "loss": average_loss,
                            "neg_si_snr_loss": average_neg_si_snr_loss,
                            "mag_loss": average_mag_loss,
                            "pha_loss": average_pha_loss,
                        })

                    total_pesq_score = 0.
                    total_loss = 0.
                    total_neg_si_snr_loss = 0.
                    total_mag_loss = 0.
                    total_pha_loss = 0.
                    total_batches = 0.

                    progress_bar_eval.close()
                    progress_bar_train = tqdm(
                        initial=progress_bar_train.n,
                        postfix=progress_bar_train.postfix,
                        desc=progress_bar_train.desc,
                    )

                    # save path
                    save_dir = serialization_dir / "steps-{}".format(step_idx)
                    save_dir.mkdir(parents=True, exist_ok=False)

                    # save models
                    model.save_pretrained(save_dir.as_posix())

                    model_list.append(save_dir)
                    if len(model_list) >= args.num_serialized_models_to_keep:
                        model_to_delete: Path = model_list.pop(0)
                        shutil.rmtree(model_to_delete.as_posix())

                    # save optim
                    torch.save(optimizer.state_dict(), (save_dir / "optimizer.pth").as_posix())

                    # save metric
                    if best_metric is None:
                        best_epoch_idx = epoch_idx
                        best_step_idx = step_idx
                        best_metric = average_pesq_score
                    elif average_pesq_score > best_metric:
                        # greater is better.
                        best_epoch_idx = epoch_idx
                        best_step_idx = step_idx
                        best_metric = average_pesq_score
                    else:
                        pass

                    metrics = {
                        "epoch_idx": epoch_idx,
                        "best_epoch_idx": best_epoch_idx,
                        "best_step_idx": best_step_idx,
                        "pesq_score": average_pesq_score,
                        "loss": average_loss,
                        "neg_si_snr_loss": average_neg_si_snr_loss,
                        "mag_loss": average_mag_loss,
                        "pha_loss": average_pha_loss,
                    }
                    metrics_filename = save_dir / "metrics_epoch.json"
                    with open(metrics_filename, "w", encoding="utf-8") as f:
                        json.dump(metrics, f, indent=4, ensure_ascii=False)

                    # save best
                    best_dir = serialization_dir / "best"
                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
                        if best_dir.exists():
                            shutil.rmtree(best_dir)
                        shutil.copytree(save_dir, best_dir)

                    # early stop
                    early_stop_flag = False
                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
                        patience_count = 0
                    else:
                        patience_count += 1
                        if patience_count >= args.patience:
                            early_stop_flag = True

                    # early stop
                    if early_stop_flag:
                        break

    return


if __name__ == "__main__":
    main()
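
For reference, the NegativeSISNRLoss used above is the negative of the standard scale-invariant SNR. With clean reference s and estimate \hat{s}:

    s_{target} = \frac{\langle \hat{s}, s \rangle}{\lVert s \rVert^{2}}\, s, \qquad
    e_{noise} = \hat{s} - s_{target}, \qquad
    \mathrm{SI\text{-}SNR}(\hat{s}, s) = 10 \log_{10} \frac{\lVert s_{target} \rVert^{2}}{\lVert e_{noise} \rVert^{2}}

The loss minimized in the training loop is its negative, so a larger SI-SNR (better separation) lowers the loss.
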
examples/frcrn/yaml/config.yaml ADDED
@@ -0,0 +1,24 @@

model_name: "frcrn"

num_gpus: -1

lr: 0.001
max_epochs: 100
weight_decay: 1.0e-05
clip_grad_norm: 10.0
seed: 1234

sample_rate: 8000
segment_size: 32000
nfft: 512
win_size: 512
hop_size: 256
win_type: hann

use_complex_networks: true
model_depth: 20
model_complexity: 45

num_workers: 4
batch_size: 4
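
A quick way to check that these keys line up with FRCRNConfig's constructor; a sketch, assuming PyYAML is installed, the script runs from the repo root, and the base PretrainedConfig passes the extra model_name key through as a kwarg:

    import yaml

    from toolbox.torchaudio.models.frcrn.configuration_frcrn import FRCRNConfig

    with open("examples/frcrn/yaml/config.yaml", "r", encoding="utf-8") as f:
        params = yaml.safe_load(f)

    # every key in the YAML should map onto a constructor argument (or be accepted as a kwarg)
    config = FRCRNConfig(**params)
    print(config.sample_rate, config.nfft, config.win_size, config.hop_size)
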
toolbox/torchaudio/losses/irm.py ADDED
@@ -0,0 +1,111 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import List

import torch
import torch.nn as nn
from torch.nn import functional as F


class CIRMLoss(nn.Module):
    def __init__(self,
                 n_fft: int = 512,
                 win_size: int = 512,
                 hop_size: int = 256,
                 center: bool = True,
                 eps: float = 1e-8,
                 reduction: str = "mean",
                 ):
        super(CIRMLoss, self).__init__()
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_size = hop_size
        self.center = center
        self.eps = eps
        self.reduction = reduction

        self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False)

        if reduction not in ("sum", "mean"):
            raise AssertionError(f"param reduction must be sum or mean.")

    def forward(self, clean: torch.Tensor, noisy: torch.Tensor, mask_real: torch.Tensor, mask_imag: torch.Tensor):
        """
        :param clean: waveform
        :param noisy: waveform
        :param mask_real: shape: [b, f, t]
        :param mask_imag: shape: [b, f, t]
        :return:
        """
        if noisy.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")

        # clean_stft, noisy_stft shape: [b, f, t]
        clean_stft = torch.stft(
            clean,
            n_fft=self.n_fft,
            win_length=self.win_size,
            hop_length=self.hop_size,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            return_complex=True
        )
        noisy_stft = torch.stft(
            noisy,
            n_fft=self.n_fft,
            win_length=self.win_size,
            hop_length=self.hop_size,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            return_complex=True
        )

        # [b, f, t]
        clean_stft_spec_real = torch.real(clean_stft)
        clean_stft_spec_imag = torch.imag(clean_stft)
        noisy_stft_spec_real = torch.real(noisy_stft)
        noisy_stft_spec_imag = torch.imag(noisy_stft)
        noisy_power = noisy_stft_spec_real ** 2 + noisy_stft_spec_imag ** 2

        sr = clean_stft_spec_real
        yr = noisy_stft_spec_real
        si = clean_stft_spec_imag
        yi = noisy_stft_spec_imag
        y_pow = noisy_power
        # (Sr * Yr + Si * Yi) / (Y_pow + 1e-8)
        gth_mask_real = (sr * yr + si * yi) / (y_pow + self.eps)
        # (Si * Yr - Sr * Yi) / (Y_pow + 1e-8)
        gth_mask_imag = (si * yr - sr * yi) / (y_pow + self.eps)

        gth_mask_real[gth_mask_real > 2] = 1
        gth_mask_real[gth_mask_real < -2] = -1
        gth_mask_imag[gth_mask_imag > 2] = 1
        gth_mask_imag[gth_mask_imag < -2] = -1

        amp_loss = F.mse_loss(gth_mask_real, mask_real)
        phase_loss = F.mse_loss(gth_mask_imag, mask_imag)

        loss = amp_loss + phase_loss
        return loss


def main():
    batch_size = 2
    signal_length = 16000
    clean_signal = torch.randn(batch_size, signal_length)
    noisy_signal = torch.randn(batch_size, signal_length)

    loss_fn = CIRMLoss()

    # random mask placeholders with the matching STFT shape [b, f, t];
    # with hop_size=256 and center=True the STFT yields signal_length // 256 + 1 frames
    num_frames = signal_length // 256 + 1
    mask_real = torch.randn(batch_size, 257, num_frames)
    mask_imag = torch.randn(batch_size, 257, num_frames)

    loss = loss_fn.forward(clean_signal, noisy_signal, mask_real, mask_imag)
    print(f"loss: {loss.item()}")

    return


if __name__ == "__main__":
    main()
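
The targets built in CIRMLoss are the complex ideal ratio mask (cIRM). Writing S = S_r + iS_i for the clean spectrum and Y = Y_r + iY_i for the noisy one:

    M = \frac{S}{Y} = \frac{S\,Y^{*}}{|Y|^{2}}
      = \frac{S_r Y_r + S_i Y_i}{Y_r^{2} + Y_i^{2}}
      + i\,\frac{S_i Y_r - S_r Y_i}{Y_r^{2} + Y_i^{2}}

This is why the imaginary component in the code uses si * yr - sr * yi (as the comment above it states), and both components are clipped when their magnitude exceeds 2 before the MSE against the predicted mask.
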
toolbox/torchaudio/losses/spectral.py CHANGED
@@ -346,6 +346,76 @@ class MultiResolutionSTFTLoss(torch.nn.Module):
         return loss
 
 
+class WeightedMagnitudePhaseLoss(nn.Module):
+    def __init__(self,
+                 n_fft: int = 1024,
+                 win_size: int = 600,
+                 hop_size: int = 120,
+                 center: bool = True,
+                 reduction: str = "mean",
+                 mag_weight: float = 0.9,
+                 pha_weight: float = 0.3,
+                 ):
+        super(WeightedMagnitudePhaseLoss, self).__init__()
+        self.n_fft = n_fft
+        self.win_size = win_size
+        self.hop_size = hop_size
+        self.center = center
+        self.reduction = reduction
+
+        self.mag_weight = mag_weight
+        self.pha_weight = pha_weight
+
+        self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False)
+
+    def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
+        """
+        :param denoise:
+        :param clean:
+        :return:
+        """
+        if denoise.shape != clean.shape:
+            raise AssertionError("Input signals must have the same shape")
+
+        # denoise_stft, clean_stft shape: [b, f, t]
+        denoise_stft = torch.stft(
+            denoise,
+            n_fft=self.n_fft,
+            win_length=self.win_size,
+            hop_length=self.hop_size,
+            window=self.window,
+            center=self.center,
+            pad_mode="reflect",
+            normalized=False,
+            return_complex=True
+        )
+        clean_stft = torch.stft(
+            clean,
+            n_fft=self.n_fft,
+            win_length=self.win_size,
+            hop_length=self.hop_size,
+            window=self.window,
+            center=self.center,
+            pad_mode="reflect",
+            normalized=False,
+            return_complex=True
+        )
+
+        denoise_stft_spec = torch.view_as_real(denoise_stft)
+        denoise_mag = torch.sqrt(denoise_stft_spec.pow(2).sum(-1) + 1e-9)
+        denoise_pha = torch.atan2(denoise_stft_spec[:, :, :, 1] + 1e-10, denoise_stft_spec[:, :, :, 0] + 1e-5)
+
+        clean_stft_spec = torch.view_as_real(clean_stft)
+        clean_mag = torch.sqrt(clean_stft_spec.pow(2).sum(-1) + 1e-9)
+        clean_pha = torch.atan2(clean_stft_spec[:, :, :, 1] + 1e-10, clean_stft_spec[:, :, :, 0] + 1e-5)
+
+        mag_loss = F.mse_loss(denoise_mag, clean_mag, reduction=self.reduction)
+        pha_loss = F.mse_loss(denoise_pha, clean_pha, reduction=self.reduction)
+
+        loss = self.mag_weight * mag_loss + self.pha_weight * pha_loss
+        return loss
+
+
 def main():
     batch_size = 2
     signal_length = 16000
@@ -354,7 +424,8 @@ def main():
 
     # loss_fn = LSDLoss()
     # loss_fn = ComplexSpectralLoss()
-    loss_fn = MultiResolutionSTFTLoss()
+    # loss_fn = MultiResolutionSTFTLoss()
+    loss_fn = WeightedMagnitudePhaseLoss()
 
     loss = loss_fn.forward(estimated_signal, target_signal)
     print(f"loss: {loss.item()}")
toolbox/torchaudio/models/frcrn/complex_nn.py ADDED
@@ -0,0 +1,258 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import Union, Tuple

import torch
import torch.nn as nn

from toolbox.torchaudio.models.frcrn.uni_deep_fsmn import UniDeepFsmn


class ComplexUniDeepFsmn(nn.Module):
    def __init__(self, input_dim: int, hidden_size: int, lorder: int = 20):
        super(ComplexUniDeepFsmn, self).__init__()

        self.fsmn_re_l1 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)
        self.fsmn_im_l1 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)
        self.fsmn_re_l2 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)
        self.fsmn_im_l2 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)

    def forward(self, x: torch.Tensor):
        """
        :param x: torch.Tensor, shape: [b, c, h, t, 2]
        :return: torch.Tensor, shape: [b, h, t, 2]
        """
        b, c, h, t, d = x.size()
        x = torch.reshape(x, shape=(b, c * h, t, d))
        # x shape: [b, h', t, 2]
        x = torch.transpose(x, dim0=1, dim1=2)
        # x shape: [b, t, h', 2]

        real_l1 = self.fsmn_re_l1(x[..., 0]) - self.fsmn_im_l1(x[..., 1])
        imaginary_l1 = self.fsmn_re_l1(x[..., 1]) + self.fsmn_im_l1(x[..., 0])
        # real, image shape: [b, t, h']

        real = self.fsmn_re_l2(real_l1) - self.fsmn_im_l2(imaginary_l1)
        imaginary = self.fsmn_re_l2(imaginary_l1) + self.fsmn_im_l2(real_l1)
        # real, image shape: [b, t, h']

        output = torch.stack(tensors=(real, imaginary), dim=-1)
        # output shape: [b, t, h', 2]
        output = torch.transpose(output, dim0=1, dim1=2)
        # output shape: [b, h', t, 2]
        output = torch.reshape(output, shape=(b, c, h, t, d))
        # output shape: [b, c, h, t, 2]
        return output


class ComplexUniDeepFsmnL1(nn.Module):
    def __init__(self, input_dim: int, hidden_size: int, lorder: int = 20):
        super(ComplexUniDeepFsmnL1, self).__init__()
        self.fsmn_re_l1 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)
        self.fsmn_im_l1 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)

    def forward(self, x: torch.Tensor):
        b, c, h, t, d = x.size()
        x = torch.transpose(x, dim0=1, dim1=3)
        # x shape: [b, t, h, c, 2]
        x = torch.reshape(x, shape=(b * t, h, c, d))
        # x shape: [b*t, h, c, 2]

        real = self.fsmn_re_l1(x[..., 0]) - self.fsmn_im_l1(x[..., 1])
        imaginary = self.fsmn_re_l1(x[..., 1]) + self.fsmn_im_l1(x[..., 0])
        # real, image shape: [b*t, h, c]

        output = torch.stack(tensors=(real, imaginary), dim=-1)
        # output shape: [b*t, h, c, 2]
        output = torch.reshape(output, shape=(b, t, h, c, d))
        # output shape: [b, t, h, c, 2]
        output = torch.transpose(output, dim0=1, dim1=3)
        # output shape: [b, c, h, t, 2]
        return output


class ComplexConv2d(nn.Module):
    # https://github.com/litcoderr/ComplexCNN/blob/master/complexcnn/modules.py
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]] = 1,
                 padding: Union[int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 groups: int = 1,
                 bias: bool = True,
                 **kwargs
                 ):
        super().__init__()

        # Model components
        self.conv_re = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            **kwargs
        )
        self.conv_im = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            **kwargs
        )

    def forward(self, x: torch.Tensor):
        """
        :param x: torch.Tensor, shape: [b, c, h, w, 2]
        :return:
        """
        real = self.conv_re(x[..., 0]) - self.conv_im(x[..., 1])
        imaginary = self.conv_re(x[..., 1]) + self.conv_im(x[..., 0])

        output = torch.stack((real, imaginary), dim=-1)
        return output


class ComplexConvTranspose2d(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]] = 1,
                 padding: Union[int, Tuple[int, int]] = 0,
                 output_padding: Union[int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 groups: int = 1,
                 bias=True,
                 **kwargs
                 ):
        super().__init__()

        # Model components
        self.tconv_re = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
            dilation=dilation,
            **kwargs
        )
        self.tconv_im = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
            dilation=dilation,
            **kwargs
        )

    def forward(self, x: torch.Tensor):
        """
        :param x: torch.Tensor, shape: [b, c, h, w, 2]
        :return:
        """
        real = self.tconv_re(x[..., 0]) - self.tconv_im(x[..., 1])
        imaginary = self.tconv_re(x[..., 1]) + self.tconv_im(x[..., 0])

        output = torch.stack((real, imaginary), dim=-1)
        return output


class ComplexBatchNorm2d(nn.Module):
    def __init__(self,
                 num_features: int,
                 eps: float = 1e-5,
                 momentum: float = 0.1,
                 affine: bool = True,
                 track_running_stats: bool = True,
                 **kwargs
                 ):
        super().__init__()
        self.bn_re = nn.BatchNorm2d(
            num_features=num_features,
            momentum=momentum,
            affine=affine,
            eps=eps,
            track_running_stats=track_running_stats,
            **kwargs
        )
        self.bn_im = nn.BatchNorm2d(
            num_features=num_features,
            momentum=momentum,
            affine=affine,
            eps=eps,
            track_running_stats=track_running_stats,
            **kwargs
        )

    def forward(self, x: torch.Tensor):
        real = self.bn_re(x[..., 0])
        imag = self.bn_im(x[..., 1])

        output = torch.stack((real, imag), dim=-1)
        return output


def main():
    # x = torch.rand(size=(1, 1, 32, 200, 2))
    # fsmn = ComplexUniDeepFsmn(
    #     input_dim=32,
    #     hidden_size=64,
    # )
    # result = fsmn.forward(x)
    # print(result.shape)

    # x = torch.rand(size=(1, 32, 32, 200, 2))
    # fsmn = ComplexUniDeepFsmnL1(
    #     input_dim=32,
    #     hidden_size=64,
    # )
    # result = fsmn.forward(x)
    # print(result.shape)

    # x = torch.rand(size=(1, 32, 200, 200, 2))
    x = torch.rand(size=(1, 1, 320, 200, 2))
    conv2d = ComplexConv2d(
        in_channels=1,
        out_channels=128,
        kernel_size=(5, 2),
        stride=(2, 1),
        padding=(0, 1),
    )
    result = conv2d.forward(x)
    print(result.shape)

    # x = torch.rand(size=(1, 32, 200, 200, 2))
    # x = torch.rand(size=(1, 64, 15, 2000, 2))
    # tconv = ComplexConvTranspose2d(
    #     in_channels=64,
    #     out_channels=32,
    #     kernel_size=(3, 3),
    #     stride=(2, 1),
    #     padding=(0, 1),
    # )
    # result = tconv.forward(x)
    # print(result.shape)
    return


if __name__ == "__main__":
    main()
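
ComplexConv2d, ComplexConvTranspose2d and the complex FSMN wrappers above all apply the same rule for a complex weight acting on a complex input. With W = W_r + iW_i and x = x_r + ix_i:

    W \ast x = (W_r \ast x_r - W_i \ast x_i) + i\,(W_r \ast x_i + W_i \ast x_r)

which is exactly the real/imaginary pair each forward() computes and stacks on the last dimension.
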
toolbox/torchaudio/models/frcrn/configuration_frcrn.py ADDED
@@ -0,0 +1,67 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://huggingface.co/spaces/alibabasglab/ClearVoice/blob/main/checkpoints/FRCRN_SE_16K/config.yaml
https://huggingface.co/spaces/alibabasglab/ClearVoice/blob/main/config/inference/FRCRN_SE_16K.yaml

"""
from toolbox.torchaudio.configuration_utils import PretrainedConfig


class FRCRNConfig(PretrainedConfig):
    def __init__(self,
                 num_gpus: int = -1,

                 lr: float = 0.001,
                 max_epochs: int = 100,
                 weight_decay: float = 0.00001,
                 clip_grad_norm: float = 10.,
                 seed: int = 1234,

                 sample_rate: int = 8000,
                 segment_size: int = 32000,
                 nfft: int = 512,
                 win_size: int = 512,
                 hop_size: int = 256,
                 win_type: str = "hann",

                 use_complex_networks: bool = True,
                 model_depth: int = 20,
                 model_complexity: int = 45,

                 num_workers: int = 4,
                 batch_size: int = 4,
                 **kwargs
                 ):
        super(FRCRNConfig, self).__init__(**kwargs)
        self.num_gpus = num_gpus

        self.lr = lr
        self.max_epochs = max_epochs
        self.weight_decay = weight_decay
        self.clip_grad_norm = clip_grad_norm
        self.seed = seed

        self.sample_rate = sample_rate
        self.segment_size = segment_size
        self.nfft = nfft
        self.win_size = win_size
        self.hop_size = hop_size
        self.win_type = win_type

        self.use_complex_networks = use_complex_networks
        self.model_depth = model_depth
        self.model_complexity = model_complexity

        self.num_workers = num_workers
        self.batch_size = batch_size


def main():
    config = FRCRNConfig()
    config.to_yaml_file("config.yaml")
    return


if __name__ == "__main__":
    main()
toolbox/torchaudio/models/frcrn/conv_stft.py ADDED
@@ -0,0 +1,147 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/modelscope/modelscope/blob/master/modelscope/models/audio/ans/conv_stft.py
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.signal import get_window


def init_kernels(nfft: int, win_size: int, hop_size: int, win_type: str = None, inverse=False):
    if win_type == "None" or win_type is None:
        window = np.ones(win_size)
    else:
        window = get_window(win_type, win_size, fftbins=True)**0.5

    fourier_basis = np.fft.rfft(np.eye(nfft))[:win_size]
    real_kernel = np.real(fourier_basis)
    image_kernel = np.imag(fourier_basis)
    kernel = np.concatenate([real_kernel, image_kernel], 1).T

    if inverse:
        kernel = np.linalg.pinv(kernel).T

    kernel = kernel * window
    kernel = kernel[:, None, :]
    result = (
        torch.from_numpy(kernel.astype(np.float32)),
        torch.from_numpy(window[None, :, None].astype(np.float32))
    )
    return result


class ConvSTFT(nn.Module):

    def __init__(self,
                 win_size: int,
                 hop_size: int,
                 nfft: int = None,  # defaults to the next power of two of win_size, mirroring ConviSTFT
                 win_type: str = "hamming",
                 feature_type: str = "real",
                 requires_grad: bool = False):
        super(ConvSTFT, self).__init__()

        if nfft is None:
            self.nfft = int(2**np.ceil(np.log2(win_size)))
        else:
            self.nfft = nfft

        kernel, _ = init_kernels(self.nfft, win_size, hop_size, win_type)
        self.weight = nn.Parameter(kernel, requires_grad=requires_grad)

        self.win_size = win_size
        self.hop_size = hop_size

        self.stride = hop_size
        self.dim = self.nfft
        self.feature_type = feature_type

    def forward(self, inputs: torch.Tensor):
        if inputs.dim() == 2:
            inputs = torch.unsqueeze(inputs, 1)

        outputs = F.conv1d(inputs, self.weight, stride=self.stride)

        if self.feature_type == "complex":
            return outputs
        else:
            dim = self.dim // 2 + 1
            real = outputs[:, :dim, :]
            imag = outputs[:, dim:, :]
            mags = torch.sqrt(real**2 + imag**2)
            phase = torch.atan2(imag, real)
            return mags, phase


class ConviSTFT(nn.Module):

    def __init__(self,
                 win_size: int,
                 hop_size: int,
                 nfft: int = None,
                 win_type: str = "hamming",
                 feature_type: str = "real",
                 requires_grad: bool = False):
        super(ConviSTFT, self).__init__()
        if nfft is None:
            self.nfft = int(2**np.ceil(np.log2(win_size)))
        else:
            self.nfft = nfft

        kernel, window = init_kernels(self.nfft, win_size, hop_size, win_type, inverse=True)
        self.weight = nn.Parameter(kernel, requires_grad=requires_grad)

        self.win_size = win_size
        self.hop_size = hop_size
        self.win_type = win_type

        self.stride = hop_size
        self.dim = self.nfft
        self.feature_type = feature_type

        self.register_buffer("window", window)
        self.register_buffer("enframe", torch.eye(win_size)[:, None, :])

    def forward(self,
                inputs: torch.Tensor,
                phase: torch.Tensor = None):
        """
        :param inputs: torch.Tensor, shape: [b, n+2, t] (complex spec) or [b, n//2+1, t] (mags)
        :param phase: torch.Tensor, shape: [b, n//2+1, t]
        :return:
        """
        if phase is not None:
            real = inputs * torch.cos(phase)
            imag = inputs * torch.sin(phase)
            inputs = torch.cat([real, imag], 1)
        outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)

        # this is from torch-stft: https://github.com/pseeth/torch-stft
        t = self.window.repeat(1, 1, inputs.size(-1))**2
        coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
        outputs = outputs / (coff + 1e-8)
        return outputs


def main():
    stft = ConvSTFT(win_size=512, hop_size=200, feature_type="complex")
    istft = ConviSTFT(win_size=512, hop_size=200, feature_type="complex")

    mixture = torch.rand(size=(1, 8000*40), dtype=torch.float32)

    spec = stft.forward(mixture)
    # shape: [batch_size, freq_bins, time_steps]
    print(spec.shape)

    waveform = istft.forward(spec)
    # shape: [batch_size, channels, num_samples]
    print(waveform.shape)

    return


if __name__ == "__main__":
    main()
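
The shapes printed by main() follow from the conv1d framing. With nfft = 512 the kernel has nfft + 2 = 514 output channels (257 real rows stacked on 257 imaginary rows), and for a signal of N samples the number of frames is

    T = \left\lfloor \frac{N - \mathrm{win\_size}}{\mathrm{hop\_size}} \right\rfloor + 1
      = \left\lfloor \frac{320000 - 512}{200} \right\rfloor + 1 = 1598

so the complex spectrum for the 40 s, 8 kHz example is [1, 514, 1598].
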
toolbox/torchaudio/models/frcrn/modeling_frcrn.py CHANGED
@@ -2,9 +2,324 @@
2
  # -*- coding: utf-8 -*-
3
  """
4
  https://arxiv.org/abs/2206.07293
 
 
 
 
5
  """
6
- from modelscope.models.audio.ans.frcrn import FRCRN
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
 
9
  if __name__ == "__main__":
10
- pass
 
2
  # -*- coding: utf-8 -*-
3
  """
4
  https://arxiv.org/abs/2206.07293
5
+
6
+ https://github.com/modelscope/modelscope/blob/master/modelscope/models/audio/ans/frcrn.py
7
+ https://huggingface.co/spaces/alibabasglab/ClearVoice/blob/main/models/frcrn_se/frcrn.py
8
+
9
  """
10
+ import os
11
+ from typing import Optional, Union
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ from torch.nn import functional as F
16
+
17
+ from toolbox.torchaudio.configuration_utils import CONFIG_FILE
18
+ from toolbox.torchaudio.models.frcrn.configuration_frcrn import FRCRNConfig
19
+ from toolbox.torchaudio.models.frcrn.conv_stft import ConviSTFT, ConvSTFT
20
+ from toolbox.torchaudio.models.frcrn.unet import UNet
21
+
22
+
23
+ class FRCRN(nn.Module):
24
+ """ Frequency Recurrent CRN """
25
+
26
+ def __init__(self,
27
+ use_complex_networks: bool = True,
28
+ model_complexity: int = 45,
29
+ model_depth: int = 14,
30
+ padding_mode: str = "zeros",
31
+ nfft: int = 640,
32
+ win_size: int = 640,
33
+ hop_size: int = 320,
34
+ win_type: str = "hann",
35
+ ):
36
+ """
37
+ :param use_complex_networks: bool, Whether to use complex networks.
38
+ :param model_complexity: int, define the model complexity with the number of layers
39
+ :param model_depth: int, Only two options are available : 10, 20
40
+ :param padding_mode: str, Encoder's convolution filter. 'zeros', 'reflect'
41
+ :param nfft: int, number of Short Time Fourier Transform (STFT) points
42
+ :param win_size: int, length of window used for defining one frame of sample points
43
+ :param hop_size: int, length of window shifting (equivalent to hop_size)
44
+ :param win_type: str, windowing type used in STFT, eg. 'hanning', 'hamming'
45
+ """
46
+ super().__init__()
47
+ self.freq_bins = nfft // 2 + 1
48
+
49
+        self.nfft = nfft
+        self.win_size = win_size
+        self.hop_size = hop_size
+        self.win_type = win_type
+
+        self.stft = ConvSTFT(
+            nfft=self.nfft,
+            win_size=self.win_size,
+            hop_size=self.hop_size,
+            feature_type="complex",
+            requires_grad=False
+        )
+        self.istft = ConviSTFT(
+            nfft=self.nfft,
+            win_size=self.win_size,
+            hop_size=self.hop_size,
+            win_type=self.win_type,
+            feature_type="complex",
+            requires_grad=False
+        )
+        self.unet = UNet(
+            in_channels=1,
+            use_complex_networks=use_complex_networks,
+            model_complexity=model_complexity,
+            model_depth=model_depth,
+            padding_mode=padding_mode
+        )
+        self.unet2 = UNet(
+            in_channels=1,
+            use_complex_networks=use_complex_networks,
+            model_complexity=model_complexity,
+            model_depth=model_depth,
+            padding_mode=padding_mode
+        )
+
+    def forward(self, noisy: torch.Tensor):
+        """
+        :param noisy: torch.Tensor, shape: [b, n_samples] or [b, c, n_samples]
+        :return:
+        """
+        if noisy.dim() == 2:
+            noisy = torch.unsqueeze(noisy, dim=1)
+        _, _, n_samples = noisy.shape
+        remainder = (n_samples - self.win_size) % self.hop_size
+        if remainder > 0:
+            n_samples_pad = self.hop_size - remainder
+            noisy = F.pad(noisy, pad=(0, n_samples_pad), mode="constant", value=0)
+
+        # [batch_size, freq_bins * 2, time_steps]
+        cmp_spec = self.stft.forward(noisy)
+        # [batch_size, 1, freq_bins * 2, time_steps]
+        cmp_spec = torch.unsqueeze(cmp_spec, 1)
+
+        # [batch_size, 2, freq_bins, time_steps]
+        cmp_spec = torch.cat([
+            cmp_spec[:, :, :self.freq_bins, :],
+            cmp_spec[:, :, self.freq_bins:, :],
+        ], dim=1)
+
+        # [batch_size, 2, freq_bins, time_steps, 1]
+        cmp_spec = torch.unsqueeze(cmp_spec, dim=4)
+
+        cmp_spec = torch.transpose(cmp_spec, 1, 4)
+        # [batch_size, 1, freq_bins, time_steps, 2]
+
+        unet1_out = self.unet.forward(cmp_spec)
+        cmp_mask1 = torch.tanh(unet1_out)
+        unet2_out = self.unet2.forward(unet1_out)
+        cmp_mask2 = torch.tanh(unet2_out)
+
+        # est_spec, est_wav, est_mask = self.apply_mask(cmp_spec, cmp_mask1)
+
+        cmp_mask2 = cmp_mask2 + cmp_mask1
+        est_spec, est_wav, est_mask = self.apply_mask(cmp_spec, cmp_mask2)
+        # est_wav shape: [b, n_samples]
+
+        est_wav = est_wav[:, :n_samples]
+        return est_spec, est_wav, est_mask
+
+    def apply_mask(self,
+                   cmp_spec: torch.Tensor,
+                   cmp_mask: torch.Tensor,
+                   ):
+        """
+        :param cmp_spec: torch.Tensor, shape: [batch_size, 1, freq_bins, time_steps, 2]
+        :param cmp_mask: torch.Tensor, shape: [batch_size, 1, freq_bins, time_steps, 2]
+        :return:
+        """
+        est_spec = torch.cat(
+            tensors=[
+                cmp_spec[..., 0] * cmp_mask[..., 0] - cmp_spec[..., 1] * cmp_mask[..., 1],
+                cmp_spec[..., 0] * cmp_mask[..., 1] + cmp_spec[..., 1] * cmp_mask[..., 0]
+            ], dim=1
+        )
+        # est_spec shape: [b, 2, n//2+1, t]
+        est_spec = torch.cat(tensors=[est_spec[:, 0, :, :], est_spec[:, 1, :, :]], dim=1)
+        # est_spec shape: [b, n+2, t]
+
+        # cmp_mask shape: [b, 1, n//2+1, t, 2]
+        cmp_mask = torch.squeeze(cmp_mask, dim=1)
+        # cmp_mask shape: [b, n//2+1, t, 2]
+        cmp_mask = torch.cat(tensors=[cmp_mask[:, :, :, 0], cmp_mask[:, :, :, 1]], dim=1)
+        # cmp_mask shape: [b, n+2, t]
+
+        # est_spec shape: [b, n+2, t]
+        est_wav = self.istft(est_spec)
+        # est_wav shape: [b, 1, n_samples]
+        est_wav = torch.squeeze(est_wav, 1)
+        # est_wav shape: [b, n_samples]
+        return est_spec, est_wav, cmp_mask
+
+    def get_params(self, weight_decay=0.0):
+        """
+        Setting weight_decay on the trainable parameters implements L2 regularization.
+        1. Prevents overfitting: adding the L2 norm (sum of squares) of the parameters to the loss as a penalty
+           limits the magnitude of the weights. The model is pushed towards smaller weight values, which lowers
+           its sensitivity to the training data and improves generalization.
+        2. Controls model complexity: weight decay acts directly in the optimization step and shrinks the weights
+           at every update, weight = weight - lr * (gradient + weight_decay * weight).
+           This adds a decay term proportional to the current weight value and suppresses rapid weight growth.
+        3. The exact effect depends on the optimizer:
+           In classic optimizers such as SGD, weight_decay is exactly equivalent to L2 regularization.
+           In Adam, the decay is coupled with the adaptive parameter update and may be weakened by the learning-rate scaling.
+           AdamW fixes this by decoupling weight decay from the learning rate, matching the theoretical L2 behaviour.
+
+        Note:
+           Too large a value causes underfitting, too small a value gives weak regularization; 1e-4 to 1e-2 is a common range.
+           Some layers (e.g. BatchNorm) may need a different weight_decay, handled via parameter groups.
+        :param weight_decay:
+        :return:
+        """
+        weights, biases = [], []
+        for name, param in self.named_parameters():
+            if "bias" in name:
+                biases += [param]
+            else:
+                weights += [param]
+
+        params = [{
+            'params': weights,
+            'weight_decay': weight_decay,
+        }, {
+            'params': biases,
+            'weight_decay': 0.0,
+        }]
+        return params
+
+    def mag_pha_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
+        """
+
+        :param est_mask: torch.Tensor, shape: [b, n+2, t]
+        :param clean:
+        :param noisy:
+        :return:
+        """
+        clean_stft = self.stft(clean)
+        clean_re = clean_stft[:, :self.freq_bins, :]
+        clean_im = clean_stft[:, self.freq_bins:, :]
+
+        noisy_stft = self.stft(noisy)
+        noisy_re = noisy_stft[:, :self.freq_bins, :]
+        noisy_im = noisy_stft[:, self.freq_bins:, :]
+
+        noisy_power = noisy_re ** 2 + noisy_im ** 2
+
+        sr = clean_re
+        yr = noisy_re
+        si = clean_im
+        yi = noisy_im
+        y_pow = noisy_power
+        # ground-truth complex ratio mask (cIRM)
+        # (Sr * Yr + Si * Yi) / (Y_pow + 1e-8)
+        gth_mask_re = (sr * yr + si * yi) / (y_pow + self.eps)
+        # (Si * Yr - Sr * Yi) / (Y_pow + 1e-8)
+        gth_mask_im = (si * yr - sr * yi) / (y_pow + self.eps)
+
+        gth_mask_re[gth_mask_re > 2] = 1
+        gth_mask_re[gth_mask_re < -2] = -1
+        gth_mask_im[gth_mask_im > 2] = 1
+        gth_mask_im[gth_mask_im < -2] = -1
+
+        mask_re = est_mask[:, :self.freq_bins, :]
+        mask_im = est_mask[:, self.freq_bins:, :]
+
+        amp_loss = F.mse_loss(gth_mask_re, mask_re)
+        phase_loss = F.mse_loss(gth_mask_im, mask_im)
+
+        return amp_loss, phase_loss
+
+
+MODEL_FILE = "model.pt"
+
+
+class FRCRNPretrainedModel(FRCRN):
+    def __init__(self,
+                 config: FRCRNConfig,
+                 ):
+        super(FRCRNPretrainedModel, self).__init__(
+            use_complex_networks=config.use_complex_networks,
+            model_complexity=config.model_complexity,
+            model_depth=config.model_depth,
+            nfft=config.nfft,
+            win_size=config.win_size,
+            hop_size=config.hop_size,
+            win_type=config.win_type,
+        )
+        self.config = config
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        config = FRCRNConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        model = cls(config)
+
+        if os.path.isdir(pretrained_model_name_or_path):
+            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
+        else:
+            ckpt_file = pretrained_model_name_or_path
+
+        with open(ckpt_file, "rb") as f:
+            state_dict = torch.load(f, map_location="cpu", weights_only=True)
+        model.load_state_dict(state_dict, strict=True)
+        return model
+
+    def save_pretrained(self,
+                        save_directory: Union[str, os.PathLike],
+                        state_dict: Optional[dict] = None,
+                        ):
+
+        model = self
+
+        if state_dict is None:
+            state_dict = model.state_dict()
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        # save state dict
+        model_file = os.path.join(save_directory, MODEL_FILE)
+        torch.save(state_dict, model_file)
+
+        # save config
+        config_file = os.path.join(save_directory, CONFIG_FILE)
+        self.config.to_yaml_file(config_file)
+        return save_directory
+
+
+def main():
+    # model = FRCRN(
+    #     use_complex_networks=True,
+    #     model_complexity=45,
+    #     model_depth=14,
+    #     padding_mode="zeros",
+    #     nfft=512,
+    #     win_size=400,
+    #     hop_size=200,
+    #     win_type="hann",
+    # )
+    model = FRCRN(
+        use_complex_networks=True,
+        model_complexity=45,
+        model_depth=14,
+        padding_mode="zeros",
+        nfft=640,
+        win_size=640,
+        hop_size=320,
+        win_type="hann",
+    )
+    mixture = torch.rand(size=(1, 8000), dtype=torch.float32)
+
+    est_spec, est_wav, est_mask = model.forward(mixture)
+    print(est_spec.shape)
+    print(est_wav.shape)
+    print(est_mask.shape)
+
+    return
 
 
 if __name__ == "__main__":
+    main()
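Usage note: the weight-decay grouping documented in get_params above can be handed straight to an optimizer, and AdamW then applies the decoupled decay per parameter group. The snippet below is only a sketch; the import path, learning rate, and decay value are illustrative assumptions, not values taken from this repository's training scripts.

    import torch

    from toolbox.torchaudio.models.frcrn.modeling_frcrn import FRCRN  # assumed module path

    model = FRCRN(
        use_complex_networks=True,
        model_complexity=45,
        model_depth=14,
        padding_mode="zeros",
        nfft=640,
        win_size=640,
        hop_size=320,
        win_type="hann",
    )

    # weights receive L2-style decay, biases are kept at weight_decay=0.0
    param_groups = model.get_params(weight_decay=1e-4)      # illustrative value
    optimizer = torch.optim.AdamW(param_groups, lr=1e-3)    # illustrative value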
toolbox/torchaudio/models/frcrn/unet.py ADDED
@@ -0,0 +1,359 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+from typing import Union, Tuple
+
+import torch
+import torch.nn as nn
+
+from toolbox.torchaudio.models.frcrn import complex_nn
+
+
+class SELayer(nn.Module):
+    def __init__(self, channels: int, reduction: int = 16):
+        super(SELayer, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+
+        self.fc_r = nn.Sequential(
+            nn.Linear(channels, channels // reduction),
+            nn.ReLU(inplace=True),
+            nn.Linear(channels // reduction, channels),
+            nn.Sigmoid()
+        )
+        self.fc_i = nn.Sequential(
+            nn.Linear(channels, channels // reduction),
+            nn.ReLU(inplace=True),
+            nn.Linear(channels // reduction, channels),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x: torch.Tensor):
+        b, c, _, _, _ = x.size()
+        x_r = self.avg_pool(x[:, :, :, :, 0]).view(b, c)
+        x_i = self.avg_pool(x[:, :, :, :, 1]).view(b, c)
+
+        y_r = self.fc_r(x_r).view(b, c, 1, 1, 1) - self.fc_i(x_i).view(b, c, 1, 1, 1)
+        y_i = self.fc_r(x_i).view(b, c, 1, 1, 1) + self.fc_i(x_r).view(b, c, 1, 1, 1)
+
+        y = torch.cat(tensors=[y_r, y_i], dim=4)
+        return x * y
+
+
+class Encoder(nn.Module):
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: Union[int, Tuple[int, int]],
+                 stride: Union[int, Tuple[int, int]],
+                 padding: Union[int, Tuple[int, int]] = None,
+                 use_complex_networks: bool = False,
+                 padding_mode: str = "zeros"
+                 ):
+        super().__init__()
+        if padding is None:
+            padding = [(k - 1) // 2 for k in kernel_size]  # 'SAME' padding
+
+        if use_complex_networks:
+            conv = complex_nn.ComplexConv2d
+            bn = complex_nn.ComplexBatchNorm2d
+        else:
+            conv = nn.Conv2d
+            bn = nn.BatchNorm2d
+
+        self.conv = conv(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            padding_mode=padding_mode
+        )
+        self.bn = bn(out_channels)
+        self.relu = nn.LeakyReLU(inplace=True)
+
+    def forward(self, x: torch.Tensor):
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.relu(x)
+        return x
+
+
+class Decoder(nn.Module):
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: Union[int, Tuple[int, int]],
+                 stride: Union[int, Tuple[int, int]],
+                 padding: Union[int, Tuple[int, int]] = (0, 0),
+                 use_complex_networks: bool = False,
+                 ):
+        super().__init__()
+        if use_complex_networks:
+            tconv = complex_nn.ComplexConvTranspose2d
+            bn = complex_nn.ComplexBatchNorm2d
+        else:
+            tconv = nn.ConvTranspose2d
+            bn = nn.BatchNorm2d
+
+        self.transconv = tconv(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding
+        )
+        self.bn = bn(out_channels)
+        self.relu = nn.LeakyReLU(inplace=True)
+
+    def forward(self, x):
+        x = self.transconv(x)
+        x = self.bn(x)
+        x = self.relu(x)
+        return x
+
+
+class UNetConfig14(object):
+    """
+    inputs x shape: [1, 1, 321, 2000, 2]
+
+    sample rate: 16000
+    nfft: 640
+    win_size: 640
+    hop_size: 320 (20ms)
+    """
+    def __init__(self, in_channels: int):
+        self.enc_channels = [in_channels, 128, 128, 128, 128, 128, 128, 128]
+        self.enc_kernel_sizes = [(5, 2), (5, 2), (5, 2), (5, 2), (5, 2), (5, 2), (2, 2)]
+        self.enc_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1)]
+        self.enc_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1)]
+
+        self.dec_channels = [64, 128, 128, 128, 128, 128, 128, 1]
+        self.dec_kernel_sizes = [(2, 2), (5, 2), (5, 2), (5, 2), (6, 2), (5, 2), (5, 2)]
+        self.dec_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1)]
+        self.dec_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1)]
+
+
+class UNetConfig10(object):
+    """
+    inputs x shape: [1, 1, 65, 200, 2]
+
+    sample rate: 8000
+    nfft: 128
+    win_size: 128
+    hop_size: 64 (8ms)
+
+    """
+    def __init__(self, in_channels: int):
+        self.enc_channels = [in_channels, 16, 32, 64, 128, 256]
+        self.enc_kernel_sizes = [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3)]
+        self.enc_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1)]
+        self.enc_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1)]
+
+        self.dec_channels = [128, 128, 64, 32, 16, 1]
+        self.dec_kernel_sizes = [(3, 3), (3, 3), (3, 3), (4, 3), (3, 3)]
+        self.dec_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1)]
+        self.dec_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1)]
+
+
+class UNetConfig20(object):
+    """
+    inputs x shape: [1, 1, 257, 2000, 2]
+
+    sample rate: 8000
+    nfft: 512
+    win_size: 512
+    hop_size: 256 (32ms)
+
+    """
+    def __init__(self, in_channels: int, model_complexity: int):
+        self.enc_channels = [
+            in_channels,
+            model_complexity, model_complexity,
+            model_complexity * 2, model_complexity * 2,
+            model_complexity * 2, model_complexity * 2,
+            model_complexity * 2, model_complexity * 2,
+            model_complexity * 2,
+            128
+        ]
+
+        self.enc_kernel_sizes = [(7, 1), (1, 7), (6, 4), (7, 5), (5, 3),
+                                 (5, 3), (5, 3), (5, 3), (5, 3), (5, 3)]
+
+        self.enc_strides = [(1, 1), (1, 1), (2, 2), (2, 1), (2, 2),
+                            (2, 1), (2, 2), (2, 1), (2, 2), (2, 1)]
+
+        self.enc_paddings = [
+            (3, 0),
+            (0, 3),
+            None,  # (0, 2),
+            None,
+            None,  # (3,1),
+            None,  # (3,1),
+            None,  # (1,2),
+            None,
+            None,
+            None
+        ]
+
+        self.dec_channels = [
+            64,
+            model_complexity * 2,
+            model_complexity * 2, model_complexity * 2,
+            model_complexity * 2, model_complexity * 2,
+            model_complexity * 2, model_complexity * 2,
+            model_complexity, model_complexity,
+            1
+        ]
+
+        self.dec_kernel_sizes = [(4, 3), (4, 2), (4, 3), (4, 2), (4, 3),
+                                 (4, 2), (6, 3), (7, 4), (1, 7), (7, 1)]
+
+        self.dec_strides = [(2, 1), (2, 2), (2, 1), (2, 2), (2, 1),
+                            (2, 2), (2, 1), (2, 2), (1, 1), (1, 1)]
+
+        self.dec_paddings = [(1, 1), (1, 0), (1, 1), (1, 0), (1, 1),
+                             (1, 0), (2, 1), (2, 1), (0, 3), (3, 0)]
+
+
+class UNet(nn.Module):
+    def __init__(self,
+                 in_channels: int = 1,
+                 use_complex_networks: bool = False,
+                 model_complexity: int = 45,
+                 model_depth: int = 20,
+                 padding_mode: str = "zeros"
+                 ):
+        super().__init__()
+        if use_complex_networks:
+            model_complexity = int(model_complexity // 1.414)
+
+        # config
+        if model_depth == 14:
+            config = UNetConfig14(in_channels)
+        elif model_depth == 10:
+            config = UNetConfig10(in_channels)
+        elif model_depth == 20:
+            config = UNetConfig20(in_channels, model_complexity)
+        else:
+            raise AssertionError(f"Unknown model depth : {model_depth}")
+
+        self.model_length = model_depth // 2
+
+        self.fsmn = complex_nn.ComplexUniDeepFsmn(
+            config.enc_channels[-1],
+            config.enc_channels[-1]
+        )
+
+        # go down
+        self.encoder_layers = nn.ModuleList(modules=[])
+        for i in range(self.model_length):
+            encoder_layer = nn.Sequential(
+                complex_nn.ComplexUniDeepFsmnL1(
+                    config.enc_channels[i],
+                    config.enc_channels[i]
+                )
+                if i != 0 else nn.Identity(),
+                Encoder(
+                    config.enc_channels[i],
+                    config.enc_channels[i + 1],
+                    kernel_size=config.enc_kernel_sizes[i],
+                    stride=config.enc_strides[i],
+                    padding=config.enc_paddings[i],
+                    use_complex_networks=use_complex_networks,
+                    padding_mode=padding_mode
+                ),
+                SELayer(config.enc_channels[i + 1], reduction=8)
+            )
+            self.encoder_layers.append(encoder_layer)
+
+        self.decoder_layers = nn.ModuleList(modules=[])
+        for i in range(self.model_length):
+            decoder_layer = nn.Sequential(
+                Decoder(
+                    config.dec_channels[i] * 2,
+                    config.dec_channels[i + 1],
+                    kernel_size=config.dec_kernel_sizes[i],
+                    stride=config.dec_strides[i],
+                    padding=config.dec_paddings[i],
+                    use_complex_networks=use_complex_networks
+                ),
+                complex_nn.ComplexUniDeepFsmnL1(
+                    config.dec_channels[i + 1],
+                    config.dec_channels[i + 1]
+                )
+                if i < (self.model_length - 1) else nn.Identity(),
+                SELayer(
+                    config.dec_channels[i + 1],
+                    reduction=8
+                )
+                if i < (self.model_length - 2) else nn.Identity()
+            )
+            self.decoder_layers.append(decoder_layer)
+
+        if use_complex_networks:
+            conv = complex_nn.ComplexConv2d
+        else:
+            conv = nn.Conv2d
+
+        self.linear = conv(
+            in_channels=config.dec_channels[-1],
+            out_channels=1,
+            kernel_size=1,
+        )
+
+    def forward(self, inputs: torch.Tensor):
+        """
+        :param inputs: torch.Tensor, shape: [b, c, f, t, 2]
+        :return:
+        """
+        x = inputs
+
+        # go down
+        xs = list()
+        xs_se = list()
+        xs_se.append(x)
+        for encoder_layer in self.encoder_layers:
+            xs.append(x)
+            # print(f"x: {x.shape}")
+            x = encoder_layer.forward(x)
+            # print(f"x: {x.shape}")
+            xs_se.append(x)
+
+        # x shape: [b, c, 1, t', 2]
+        x = self.fsmn.forward(x)
+        # x shape: [b, c, 1, t', 2]
+        # print(f"fsmn")
+
+        p = x
+        for i, decoder_layers in enumerate(self.decoder_layers):
+            # print(f"x: {x.shape}")
+            p = decoder_layers.forward(p)
+            # print(f"p: {p.shape}")
+            if i == self.model_length - 1:
+                break
+            p = torch.cat(tensors=[p, xs_se[self.model_length - 1 - i]], dim=1)
+
+        # cmp_spec: [1, 1, 321, 200, 2]
+        # cmp_spec: [1, 1, 513, 200, 2]
+        cmp_spec = self.linear.forward(p)
+        return cmp_spec
+
+
+def main():
+    # [batch_size, 1, freq_bins, time_steps, 2]
+    x = torch.rand(size=(1, 1, 257, 2000, 2))
+    # x = torch.rand(size=(1, 1, 256, 2000, 2))
+    # x = torch.rand(size=(1, 1, 255, 2000, 2))
+    unet = UNet(
+        in_channels=1,
+        model_complexity=45,
+        model_depth=20,
+        use_complex_networks=True
+    )
+    print(unet)
+    result = unet.forward(x)
+    print(result.shape)
+    return
+
+
+if __name__ == "__main__":
+    main()
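Side note on the configs above: the frequency axis of the [b, c, f, t, 2] input is nfft // 2 + 1 bins, so each model_depth pairs with the STFT settings listed in its docstring (depth 14 with 321 bins at 16 kHz, depth 10 with 65 bins, depth 20 with 257 bins). A minimal sketch of the depth-14 variant follows; the batch size and number of time frames are illustrative, not taken from the training configs.

    import torch

    from toolbox.torchaudio.models.frcrn.unet import UNet

    # model_depth=14 -> UNetConfig14, nfft=640, 640 // 2 + 1 = 321 freq bins
    # model_depth=10 -> UNetConfig10, nfft=128, 128 // 2 + 1 = 65  freq bins
    # model_depth=20 -> UNetConfig20, nfft=512, 512 // 2 + 1 = 257 freq bins
    unet = UNet(
        in_channels=1,
        use_complex_networks=True,
        model_complexity=45,
        model_depth=14,
    )
    x = torch.rand(size=(1, 1, 321, 200, 2))  # [b, c, freq_bins, time_steps, 2]
    y = unet.forward(x)
    print(y.shape)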
toolbox/torchaudio/models/frcrn/uni_deep_fsmn.py ADDED
@@ -0,0 +1,71 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+"""
+https://github.com/modelscope/modelscope/blob/master/modelscope/models/audio/ans/layers/uni_deep_fsmn.py
+https://huggingface.co/spaces/alibabasglab/ClearVoice/blob/main/models/mossformer2_se/fsmn.py
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class UniDeepFsmn(nn.Module):
+
+    def __init__(self,
+                 input_dim: int,
+                 hidden_size: int,
+                 lorder: int = 1,
+                 ):
+        super(UniDeepFsmn, self).__init__()
+        self.input_dim = input_dim
+        self.hidden_size = hidden_size
+        self.lorder = lorder
+
+        self.linear = nn.Linear(input_dim, hidden_size)
+        self.project = nn.Linear(hidden_size, input_dim, bias=False)
+        self.conv1 = nn.Conv2d(
+            input_dim,
+            input_dim,
+            kernel_size=(lorder, 1),
+            stride=(1, 1),
+            groups=input_dim,
+            bias=False
+        )
+
+    def forward(self, inputs: torch.Tensor):
+        """
+        :param inputs: torch.Tensor, shape: [b, t, h]
+        :return: torch.Tensor, shape: [b, t, h]
+        """
+        x = F.relu(self.linear(inputs))
+        x = self.project(x)
+        x = torch.unsqueeze(x, 1)
+        # x shape: [b, 1, t, h]
+
+        x = x.permute(0, 3, 2, 1)
+        # x shape: [b, h, t, 1]
+        # left-pad the time axis so the depthwise conv only sees past frames (causal memory)
+        y = F.pad(x, [0, 0, self.lorder - 1, 0])
+
+        x = x + self.conv1(y)
+        x = x.permute(0, 3, 2, 1)
+        # x shape: [b, 1, t, h]
+        x = x.squeeze(dim=1)
+
+        result = inputs + x
+        return result
+
+
+def main():
+    x = torch.rand(size=(1, 200, 32))
+    fsmn = UniDeepFsmn(
+        input_dim=32,
+        hidden_size=64,
+        lorder=3,
+    )
+    result = fsmn.forward(x)
+    print(result.shape)
+    return
+
+
+if __name__ == "__main__":
+    main()
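A small self-contained check of the memory mechanism used above: left-padding the time axis by lorder - 1 and applying a depthwise (groups=input_dim) convolution with kernel (lorder, 1) makes every output frame a weighted sum of the current and the previous lorder - 1 frames of the same feature dimension. The tensor sizes and the probed frame index below are illustrative.

    import torch
    import torch.nn.functional as F

    b, t, h, lorder = 1, 6, 4, 3
    x = torch.rand(size=(b, h, t, 1))            # [b, h, t, 1], as inside UniDeepFsmn.forward
    weight = torch.rand(size=(h, 1, lorder, 1))  # one filter per feature dimension (groups=h)

    # causal depthwise conv over the left-padded time axis
    y = F.conv2d(F.pad(x, [0, 0, lorder - 1, 0]), weight, groups=h)

    # manual weighted sum of frames 2, 3, 4 for feature dimension 0
    frame = 4
    manual = sum(weight[0, 0, k, 0] * x[0, 0, frame - (lorder - 1) + k, 0] for k in range(lorder))
    print(torch.allclose(y[0, 0, frame, 0], manual))  # expected: True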