HoneyTian committed on
Commit 2171fed · 1 Parent(s): a645af7

add dfnet2
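Rename the clean_unet_aishell example to clean_unet; remove the conv_tasnet_gan example; make the conv_tasnet, dfnet, dfnet2, dtln, frcrn, lstm, and rnnoise training scripts call model.eval() before the periodic evaluation and model.train() when training resumes; detach spec, feat_erb, and feat_spec in DfNet2 so gradients do not flow through feature preparation; add NumPy exponential-moving-average normalization helpers for ERB and spectral features to toolbox/torchaudio/modules/utils/ema.py.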
examples/{clean_unet_aishell → clean_unet}/run.sh RENAMED
File without changes
examples/{clean_unet_aishell → clean_unet}/step_1_prepare_data.py RENAMED
File without changes
examples/{clean_unet_aishell → clean_unet}/step_2_train_model.py RENAMED
File without changes
examples/{clean_unet_aishell → clean_unet}/step_3_evaluation.py RENAMED
File without changes
examples/{clean_unet_aishell → clean_unet}/yaml/config.yaml RENAMED
File without changes
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -346,6 +346,7 @@ def main():
             # evaluation
             step_idx += 1
             if step_idx % config.eval_steps == 0:
+                model.eval()
                 with torch.no_grad():
                     torch.cuda.empty_cache()
 
@@ -499,6 +500,7 @@ def main():
                     # early stop
                     if early_stop_flag:
                         break
+                model.train()
 
     return
 
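Each of the training-script diffs in this commit applies the same two-line change: enter eval mode before the periodic evaluation and restore train mode afterwards. A minimal, self-contained sketch of the pattern (the model, optimizer, data, and step counts below are stand-ins, not code from this repository):

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.1))
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    eval_steps = 5

    model.train()
    for step_idx in range(1, 11):
        x = torch.randn(4, 8)
        loss = model(x).pow(2).mean()  # dummy training objective
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step_idx % eval_steps == 0:
            model.eval()  # disable dropout (and batch-norm updates) for evaluation
            with torch.no_grad():  # skip autograd bookkeeping
                val_loss = model(torch.randn(4, 8)).pow(2).mean()
            model.train()  # restore training mode, as the diffs here do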
examples/conv_tasnet_gan/run.sh DELETED
@@ -1,156 +0,0 @@
-#!/usr/bin/env bash
-
-: <<'END'
-
-
-sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name conv-tasnet-dns3-20250319 \
---noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
---speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" \
---max_epochs 400
-
-
-END
-
-
-# params
-system_version="windows";
-verbose=true;
-stage=0 # start from 0 if you need to start from data preparation
-stop_stage=9
-
-work_dir="$(pwd)"
-file_folder_name=file_folder_name
-final_model_name=final_model_name
-config_file="yaml/config.yaml"
-discriminator_config_file="yaml/discriminator_config.yaml"
-limit=10
-
-noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
-speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
-
-max_count=10000000
-
-nohup_name=nohup.out
-
-# model params
-batch_size=64
-max_epochs=200
-save_top_k=10
-patience=5
-
-
-# parse options
-while true; do
-  [ -z "${1:-}" ] && break;  # break if there are no arguments
-  case "$1" in
-    --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
-      eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
-      old_value="(eval echo \\$$name)";
-      if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
-        was_bool=true;
-      else
-        was_bool=false;
-      fi
-
-      # Set the variable to the right value-- the escaped quotes make it work if
-      # the option had spaces, like --cmd "queue.pl -sync y"
-      eval "${name}=\"$2\"";
-
-      # Check that Boolean-valued arguments are really Boolean.
-      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
-        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
-        exit 1;
-      fi
-      shift 2;
-      ;;
-
-    *) break;
-  esac
-done
-
-file_dir="${work_dir}/${file_folder_name}"
-final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
-evaluation_audio_dir="${file_dir}/evaluation_audio"
-
-train_dataset="${file_dir}/train.jsonl"
-valid_dataset="${file_dir}/valid.jsonl"
-
-$verbose && echo "system_version: ${system_version}"
-$verbose && echo "file_folder_name: ${file_folder_name}"
-
-if [ $system_version == "windows" ]; then
-  alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
-elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
-  #source /data/local/bin/nx_denoise/bin/activate
-  alias python3='/data/local/bin/nx_denoise/bin/python3'
-fi
-
-
-if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-  $verbose && echo "stage 1: prepare data"
-  cd "${work_dir}" || exit 1
-  python3 step_1_prepare_data.py \
-    --file_dir "${file_dir}" \
-    --noise_dir "${noise_dir}" \
-    --speech_dir "${speech_dir}" \
-    --train_dataset "${train_dataset}" \
-    --valid_dataset "${valid_dataset}" \
-    --max_count "${max_count}" \
-
-fi
-
-
-if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
-  $verbose && echo "stage 2: train model"
-  cd "${work_dir}" || exit 1
-  python3 step_2_train_model.py \
-    --train_dataset "${train_dataset}" \
-    --valid_dataset "${valid_dataset}" \
-    --serialization_dir "${file_dir}" \
-    --config_file "${config_file}" \
-    --discriminator_config_file "${discriminator_config_file}" \
-
-fi
-
-
-if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
-  $verbose && echo "stage 3: test model"
-  cd "${work_dir}" || exit 1
-  python3 step_3_evaluation.py \
-    --valid_dataset "${valid_dataset}" \
-    --model_dir "${file_dir}/best" \
-    --evaluation_audio_dir "${evaluation_audio_dir}" \
-    --limit "${limit}" \
-
-fi
-
-
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
-  $verbose && echo "stage 4: collect files"
-  cd "${work_dir}" || exit 1
-
-  mkdir -p ${final_model_dir}
-
-  cp "${file_dir}/best"/* "${final_model_dir}"
-  cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
-
-  cd "${final_model_dir}/.." || exit 1;
-
-  if [ -e "${final_model_name}.zip" ]; then
-    rm -rf "${final_model_name}_backup.zip"
-    mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
-  fi
-
-  zip -r "${final_model_name}.zip" "${final_model_name}"
-  rm -rf "${final_model_name}"
-
-fi
-
-
-if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
-  $verbose && echo "stage 5: clear file_dir"
-  cd "${work_dir}" || exit 1
-
-  rm -rf "${file_dir}";
-
-fi
examples/conv_tasnet_gan/step_1_prepare_data.py DELETED
@@ -1,162 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-import argparse
-import json
-import os
-from pathlib import Path
-import random
-import sys
-
-pwd = os.path.abspath(os.path.dirname(__file__))
-sys.path.append(os.path.join(pwd, "../../"))
-
-import librosa
-import numpy as np
-from tqdm import tqdm
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--file_dir", default="./", type=str)
-
-    parser.add_argument(
-        "--noise_dir",
-        default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
-        type=str
-    )
-    parser.add_argument(
-        "--speech_dir",
-        default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
-        type=str
-    )
-
-    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
-    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
-
-    parser.add_argument("--duration", default=4.0, type=float)
-    parser.add_argument("--min_snr_db", default=-10, type=float)
-    parser.add_argument("--max_snr_db", default=20, type=float)
-
-    parser.add_argument("--target_sample_rate", default=8000, type=int)
-
-    parser.add_argument("--max_count", default=10000, type=int)
-
-    args = parser.parse_args()
-    return args
-
-
-def filename_generator(data_dir: str):
-    data_dir = Path(data_dir)
-    for filename in data_dir.glob("**/*.wav"):
-        yield filename.as_posix()
-
-
-def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000, max_epoch: int = 20000):
-    data_dir = Path(data_dir)
-    for epoch_idx in range(max_epoch):
-        for filename in data_dir.glob("**/*.wav"):
-            signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
-            raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
-
-            if raw_duration < duration:
-                # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
-                continue
-            if signal.ndim != 1:
-                raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
-
-            signal_length = len(signal)
-            win_size = int(duration * sample_rate)
-            for begin in range(0, signal_length - win_size, win_size):
-                if np.sum(signal[begin: begin+win_size]) == 0:
-                    continue
-                row = {
-                    "epoch_idx": epoch_idx,
-                    "filename": filename.as_posix(),
-                    "raw_duration": round(raw_duration, 4),
-                    "offset": round(begin / sample_rate, 4),
-                    "duration": round(duration, 4),
-                }
-                yield row
-
-
-def main():
-    args = get_args()
-
-    file_dir = Path(args.file_dir)
-    file_dir.mkdir(exist_ok=True)
-
-    noise_dir = Path(args.noise_dir)
-    speech_dir = Path(args.speech_dir)
-
-    noise_generator = target_second_signal_generator(
-        noise_dir.as_posix(),
-        duration=args.duration,
-        sample_rate=args.target_sample_rate,
-        max_epoch=100000,
-    )
-    speech_generator = target_second_signal_generator(
-        speech_dir.as_posix(),
-        duration=args.duration,
-        sample_rate=args.target_sample_rate,
-        max_epoch=1,
-    )
-
-    dataset = list()
-
-    count = 0
-    process_bar = tqdm(desc="build dataset excel")
-    with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
-        for noise, speech in zip(noise_generator, speech_generator):
-            if count >= args.max_count:
-                break
-
-            noise_filename = noise["filename"]
-            noise_raw_duration = noise["raw_duration"]
-            noise_offset = noise["offset"]
-            noise_duration = noise["duration"]
-
-            speech_filename = speech["filename"]
-            speech_raw_duration = speech["raw_duration"]
-            speech_offset = speech["offset"]
-            speech_duration = speech["duration"]
-
-            random1 = random.random()
-            random2 = random.random()
-
-            row = {
-                "noise_filename": noise_filename,
-                "noise_raw_duration": noise_raw_duration,
-                "noise_offset": noise_offset,
-                "noise_duration": noise_duration,
-
-                "speech_filename": speech_filename,
-                "speech_raw_duration": speech_raw_duration,
-                "speech_offset": speech_offset,
-                "speech_duration": speech_duration,
-
-                "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
-
-                "random1": random1,
-            }
-            row = json.dumps(row, ensure_ascii=False)
-            if random2 < (1 / 300 / 1):
-                fvalid.write(f"{row}\n")
-            else:
-                ftrain.write(f"{row}\n")
-
-            count += 1
-            duration_seconds = count * args.duration
-            duration_hours = duration_seconds / 3600
-
-            process_bar.update(n=1)
-            process_bar.set_postfix({
-                # "duration_seconds": round(duration_seconds, 4),
-                "duration_hours": round(duration_hours, 4),
-
-            })
-
-    return
-
-
-if __name__ == "__main__":
-    main()
examples/conv_tasnet_gan/step_2_train_model.py DELETED
@@ -1,582 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-"""
-https://github.com/kaituoxu/Conv-TasNet/tree/master/src
-
-General scenarios:
-
-Target SI-SNR ≥ 10 dB; suitable for telephone communication, basic voice assistants, etc.
-
-Demanding scenarios (e.g. medical hearing aids, speech recognition):
-require SI-SNR ≥ 14 dB, together with PESQ ≥ 3.0 and STOI ≥ 0.85.
-
-The DeepFilterNet2 model was trained for 100 epochs on the DNS4 dataset, over 500 hours of audio.
-https://arxiv.org/abs/2205.05474
-
-"""
-import argparse
-import json
-import logging
-from logging.handlers import TimedRotatingFileHandler
-import os
-import platform
-from pathlib import Path
-import random
-import sys
-import shutil
-from typing import List
-
-pwd = os.path.abspath(os.path.dirname(__file__))
-sys.path.append(os.path.join(pwd, "../../"))
-
-import numpy as np
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from torch.utils.data.dataloader import DataLoader
-from tqdm import tqdm
-
-from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
-from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import ConvTasNetConfig
-from toolbox.torchaudio.models.conv_tasnet.modeling_conv_tasnet import ConvTasNet, ConvTasNetPretrainedModel
-from toolbox.torchaudio.models.discriminators.waveform_metric_discriminator.modeling_waveform_metric_discriminator import WaveformMetricDiscriminatorPretrainedModel
-from toolbox.torchaudio.models.discriminators.waveform_metric_discriminator.configuration_waveform_metric_discriminator import WaveformMetricDiscriminatorConfig
-from toolbox.torchaudio.models.nx_clean_unet.metrics import run_batch_pesq, run_pesq_score
-from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
-from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
-from toolbox.torchaudio.losses.perceptual import NegSTOILoss, PesqLoss
-from toolbox.torchaudio.metrics.pesq import run_pesq_score
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
-    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
-
-    parser.add_argument("--max_epochs", default=200, type=int)
-
-    parser.add_argument("--batch_size", default=64, type=int)
-    parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
-    parser.add_argument("--patience", default=5, type=int)
-    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
-    parser.add_argument("--seed", default=1234, type=int)
-
-    parser.add_argument("--config_file", default="config.yaml", type=str)
-    parser.add_argument("--discriminator_config_file", default="discriminator_config.yaml", type=str)
-
-    args = parser.parse_args()
-    return args
-
-
-def logging_config(file_dir: str):
-    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
-
-    logging.basicConfig(format=fmt,
-                        datefmt="%m/%d/%Y %H:%M:%S",
-                        level=logging.INFO)
-    file_handler = TimedRotatingFileHandler(
-        filename=os.path.join(file_dir, "main.log"),
-        encoding="utf-8",
-        when="D",
-        interval=1,
-        backupCount=7
-    )
-    file_handler.setLevel(logging.INFO)
-    file_handler.setFormatter(logging.Formatter(fmt))
-    logger = logging.getLogger(__name__)
-    logger.addHandler(file_handler)
-
-    return logger
-
-
-class CollateFunction(object):
-    def __init__(self):
-        pass
-
-    def __call__(self, batch: List[dict]):
-        clean_audios = list()
-        noisy_audios = list()
-
-        for sample in batch:
-            # noise_wave: torch.Tensor = sample["noise_wave"]
-            clean_audio: torch.Tensor = sample["speech_wave"]
-            noisy_audio: torch.Tensor = sample["mix_wave"]
-            # snr_db: float = sample["snr_db"]
-
-            clean_audios.append(clean_audio)
-            noisy_audios.append(noisy_audio)
-
-        clean_audios = torch.stack(clean_audios)
-        noisy_audios = torch.stack(noisy_audios)
-
-        # assert
-        if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
-            raise AssertionError("nan or inf in clean_audios")
-        if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
-            raise AssertionError("nan or inf in noisy_audios")
-        return clean_audios, noisy_audios
-
-
-collate_fn = CollateFunction()
-
-
-def main():
-    args = get_args()
-
-    config = ConvTasNetConfig.from_pretrained(
-        pretrained_model_name_or_path=args.config_file,
-    )
-    discriminator_config = WaveformMetricDiscriminatorConfig.from_pretrained(
-        pretrained_model_name_or_path=args.discriminator_config_file,
-    )
-
-    serialization_dir = Path(args.serialization_dir)
-    serialization_dir.mkdir(parents=True, exist_ok=True)
-
-    logger = logging_config(serialization_dir)
-
-    random.seed(args.seed)
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    logger.info(f"set seed: {args.seed}")
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    n_gpu = torch.cuda.device_count()
-    logger.info(f"GPU available count: {n_gpu}; device: {device}")
-
-    # datasets
-    train_dataset = DenoiseJsonlDataset(
-        jsonl_file=args.train_dataset,
-        expected_sample_rate=config.sample_rate,
-        max_wave_value=32768.0,
-        min_snr_db=config.min_snr_db,
-        max_snr_db=config.max_snr_db,
-        # skip=825000,
-    )
-    valid_dataset = DenoiseJsonlDataset(
-        jsonl_file=args.valid_dataset,
-        expected_sample_rate=config.sample_rate,
-        max_wave_value=32768.0,
-        min_snr_db=config.min_snr_db,
-        max_snr_db=config.max_snr_db,
-    )
-    train_data_loader = DataLoader(
-        dataset=train_dataset,
-        batch_size=args.batch_size,
-        # shuffle=True,
-        sampler=None,
-        # On Linux, data can be loaded with multiple worker subprocesses; on Windows it cannot.
-        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
-        collate_fn=collate_fn,
-        pin_memory=False,
-        prefetch_factor=2,
-    )
-    valid_data_loader = DataLoader(
-        dataset=valid_dataset,
-        batch_size=args.batch_size,
-        # shuffle=True,
-        sampler=None,
-        # On Linux, data can be loaded with multiple worker subprocesses; on Windows it cannot.
-        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
-        collate_fn=collate_fn,
-        pin_memory=False,
-        prefetch_factor=2,
-    )
-
-    # models
-    logger.info(f"prepare models. config_file: {args.config_file}")
-    model = ConvTasNetPretrainedModel(config).to(device)
-    model.to(device)
-    model.train()
-
-    discriminator = WaveformMetricDiscriminatorPretrainedModel(discriminator_config).to(device)
-    discriminator.to(device)
-    discriminator.train()
-
-    # optimizer
-    logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
-    optimizer = torch.optim.AdamW(model.parameters(), config.lr, betas=[config.adam_b1, config.adam_b2])
-    discriminator_optimizer = torch.optim.AdamW(discriminator.parameters(), config.lr, betas=[config.adam_b1, config.adam_b2])
-
-    # resume training
-    last_step_idx = -1
-    last_epoch = -1
-    for step_idx_str in serialization_dir.glob("steps-*"):
-        step_idx_str = Path(step_idx_str)
-        step_idx = step_idx_str.stem.split("-")[1]
-        step_idx = int(step_idx)
-        if step_idx > last_step_idx:
-            last_step_idx = step_idx
-
-    if last_step_idx != -1:
-        logger.info(f"resume from steps-{last_step_idx}.")
-        model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
-        optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"
-
-        discriminator_pt = serialization_dir / f"steps-{last_step_idx}/discriminator.pt"
-        discriminator_optimizer_pth = serialization_dir / f"steps-{last_step_idx}/discriminator_optimizer.pth"
-
-        logger.info(f"load state dict for model.")
-        with open(model_pt.as_posix(), "rb") as f:
-            state_dict = torch.load(f, map_location="cpu", weights_only=True)
-        model.load_state_dict(state_dict, strict=True)
-
-        if optimizer_pth.exists():
-            logger.info(f"load state dict for optimizer.")
-            with open(optimizer_pth.as_posix(), "rb") as f:
-                state_dict = torch.load(f, map_location="cpu", weights_only=True)
-            optimizer.load_state_dict(state_dict)
-
-        if discriminator_pt.exists():
-            logger.info(f"load state dict for discriminator.")
-            with open(model_pt.as_posix(), "rb") as f:
-                state_dict = torch.load(f, map_location="cpu", weights_only=True)
-            discriminator.load_state_dict(state_dict, strict=True)
-
-        if discriminator_optimizer_pth.exists():
-            logger.info(f"load state dict for discriminator_optimizer.")
-            with open(optimizer_pth.as_posix(), "rb") as f:
-                state_dict = torch.load(f, map_location="cpu", weights_only=True)
-            discriminator_optimizer.load_state_dict(state_dict)
-
-    if config.lr_scheduler == "CosineAnnealingLR":
-        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
-            optimizer,
-            last_epoch=last_epoch,
-            # T_max=10 * config.eval_steps,
-            # eta_min=0.01 * config.lr,
-            **config.lr_scheduler_kwargs,
-        )
-        discriminator_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
-            discriminator_optimizer,
-            last_epoch=last_epoch,
-            # T_max=10 * config.eval_steps,
-            # eta_min=0.01 * config.lr,
-            **config.lr_scheduler_kwargs,
-        )
-    elif config.lr_scheduler == "MultiStepLR":
-        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
-            optimizer,
-            last_epoch=last_epoch,
-            milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
-        )
-        discriminator_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
-            discriminator_optimizer,
-            last_epoch=last_epoch,
-            milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
-        )
-    else:
-        raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
-
-    ae_loss_fn = nn.L1Loss(reduction="mean").to(device)
-    neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
-    neg_stoi_loss_fn = NegSTOILoss(sample_rate=config.sample_rate, reduction="mean").to(device)
-    mr_stft_loss_fn = MultiResolutionSTFTLoss(
-        fft_size_list=[256, 512, 1024],
-        win_size_list=[120, 240, 480],
-        hop_size_list=[25, 50, 100],
-        factor_sc=1.5,
-        factor_mag=1.0,
-        reduction="mean"
-    ).to(device)
-    pesq_loss_fn = PesqLoss(0.5, sample_rate=config.sample_rate).to(device)
-
-    # training loop
-
-    # state
-    average_pesq_score = 1000000000
-    average_loss = 1000000000
-    average_ae_loss = 1000000000
-    average_neg_si_snr_loss = 1000000000
-    average_neg_stoi_loss = 1000000000
-    average_mr_stft_loss = 1000000000
-    average_pesq_loss = 1000000000
-    average_discriminator_g_loss = 1000000000
-    average_discriminator_d_loss = 1000000000
-
-    model_list = list()
-    best_epoch_idx = None
-    best_step_idx = None
-    best_metric = None
-    patience_count = 0
-
-    step_idx = 0 if last_step_idx == -1 else last_step_idx
-
-    logger.info("training")
-    for epoch_idx in range(max(0, last_epoch+1), args.max_epochs):
-        # train
-        model.train()
-
-        total_pesq_score = 0.
-        total_loss = 0.
-        total_ae_loss = 0.
-        total_neg_si_snr_loss = 0.
-        total_neg_stoi_loss = 0.
-        total_mr_stft_loss = 0.
-        total_pesq_loss = 0.
-        total_discriminator_g_loss = 0.
-        total_discriminator_d_loss = 0.
-        total_batches = 0.
-
-        progress_bar_train = tqdm(
-            initial=step_idx,
-            desc="Training; epoch-{}".format(epoch_idx),
-        )
-        for train_batch in train_data_loader:
-            clean_audios, noisy_audios = train_batch
-            clean_audios: torch.Tensor = clean_audios.to(device)
-            noisy_audios: torch.Tensor = noisy_audios.to(device)
-            one_labels = torch.ones(clean_audios.shape[0]).to(device)
-
-            denoise_audios = model.forward(noisy_audios)
-            denoise_audios = torch.squeeze(denoise_audios, dim=1)
-
-            if torch.any(torch.isnan(denoise_audios)) or torch.any(torch.isinf(denoise_audios)):
-                raise AssertionError("nan or inf in denoise_audios")
-
-            # Discriminator
-            clean_audio_list = torch.split(clean_audios, 1, dim=0)
-            enhanced_audio_list = torch.split(denoise_audios, 1, dim=0)
-            clean_audio_list = [t.squeeze().detach().cpu().numpy() for t in clean_audio_list]
-            enhanced_audio_list = [t.squeeze().detach().cpu().numpy() for t in enhanced_audio_list]
-
-            pesq_score_list: List[float] = run_batch_pesq(clean_audio_list, enhanced_audio_list, sample_rate=config.sample_rate, mode="nb")
-
-            metric_r = discriminator.forward(clean_audios, clean_audios)
-            metric_g = discriminator.forward(denoise_audios.detach(), clean_audios)
-            loss_disc_r = F.mse_loss(one_labels, metric_r.flatten())
-
-            if -1 in pesq_score_list:
-                # print("-1 in batch_pesq_score!")
-                loss_disc_g = 0
-            else:
-                pesq_score_list: torch.FloatTensor = torch.tensor([(score - 1) / 3.5 for score in pesq_score_list], dtype=torch.float32)
-                loss_disc_g = F.mse_loss(pesq_score_list.to(device), metric_g.flatten())
-
-            discriminator_d_loss = loss_disc_r + loss_disc_g
-            discriminator_optimizer.zero_grad()
-            discriminator_d_loss.backward()
-            discriminator_optimizer.step()
-            discriminator_lr_scheduler.step()
-
-            # Generator
-            ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios)
-            neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
-            neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
-            mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
-            pesq_loss = pesq_loss_fn.forward(denoise_audios, clean_audios)
-
-            metric_g = discriminator.forward(denoise_audios, clean_audios)
-            discriminator_g_loss = F.mse_loss(metric_g.flatten(), one_labels)
-
-            loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss + 0.2 * discriminator_g_loss
-            if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
-                logger.info(f"find nan or inf in loss.")
-                continue
-
-            denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
-            clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
-            pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
-
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-            lr_scheduler.step()
-
-            total_pesq_score += pesq_score
-            total_loss += loss.item()
-            total_ae_loss += ae_loss.item()
-            total_neg_si_snr_loss += neg_si_snr_loss.item()
-            total_neg_stoi_loss += neg_stoi_loss.item()
-            total_mr_stft_loss += mr_stft_loss.item()
-            total_pesq_loss += pesq_loss.item()
-            total_discriminator_g_loss += discriminator_g_loss.item()
-            total_discriminator_d_loss += discriminator_d_loss.item()
-            total_batches += 1
-
-            average_pesq_score = round(total_pesq_score / total_batches, 4)
-            average_loss = round(total_loss / total_batches, 4)
-            average_ae_loss = round(total_ae_loss / total_batches, 4)
-            average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
-            average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
-            average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
-            average_pesq_loss = round(total_pesq_loss / total_batches, 4)
-            average_discriminator_g_loss = round(total_discriminator_g_loss / total_batches, 4)
-            average_discriminator_d_loss = round(total_discriminator_d_loss / total_batches, 4)
-
-            progress_bar_train.update(1)
-            progress_bar_train.set_postfix({
-                "lr": lr_scheduler.get_last_lr()[0],
-                "pesq_score": average_pesq_score,
-                "loss": average_loss,
-                "ae_loss": average_ae_loss,
-                "neg_si_snr_loss": average_neg_si_snr_loss,
-                "neg_stoi_loss": average_neg_stoi_loss,
-                "mr_stft_loss": average_mr_stft_loss,
-                "pesq_loss": average_pesq_loss,
-                "disc_g_loss": average_discriminator_g_loss,
-                "disc_d_loss": average_discriminator_d_loss,
-
-            })
-
-            # evaluation
-            step_idx += 1
-            if step_idx % config.eval_steps == 0:
-                with torch.no_grad():
-                    torch.cuda.empty_cache()
-
-                    total_pesq_score = 0.
-                    total_loss = 0.
-                    total_ae_loss = 0.
-                    total_neg_si_snr_loss = 0.
-                    total_neg_stoi_loss = 0.
-                    total_mr_stft_loss = 0.
-                    total_pesq_loss = 0.
-                    total_batches = 0.
-
-                    progress_bar_train.close()
-                    progress_bar_eval = tqdm(
-                        desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
-                    )
-                    for eval_batch in valid_data_loader:
-                        clean_audios, noisy_audios = eval_batch
-                        clean_audios = clean_audios.to(device)
-                        noisy_audios = noisy_audios.to(device)
-
-                        denoise_audios = model.forward(noisy_audios)
-                        denoise_audios = torch.squeeze(denoise_audios, dim=1)
-
-                        # Generator
-                        ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios)
-                        neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
-                        neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
-                        mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
-                        pesq_loss = pesq_loss_fn.forward(denoise_audios, clean_audios)
-
-                        loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
-                        if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
-                            logger.info(f"find nan or inf in loss.")
-                            continue
-
-                        denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
-                        clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
-                        pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
-
-                        total_pesq_score += pesq_score
-                        total_loss += loss.item()
-                        total_ae_loss += ae_loss.item()
-                        total_neg_si_snr_loss += neg_si_snr_loss.item()
-                        total_neg_stoi_loss += neg_stoi_loss.item()
-                        total_mr_stft_loss += mr_stft_loss.item()
-                        total_pesq_loss += pesq_loss.item()
-                        total_batches += 1
-
-                        average_pesq_score = round(total_pesq_score / total_batches, 4)
-                        average_loss = round(total_loss / total_batches, 4)
-                        average_ae_loss = round(total_ae_loss / total_batches, 4)
-                        average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
-                        average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
-                        average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
-                        average_pesq_loss = round(total_pesq_loss / total_batches, 4)
-
-                        progress_bar_eval.update(1)
-                        progress_bar_eval.set_postfix({
-                            "lr": lr_scheduler.get_last_lr()[0],
-                            "pesq_score": average_pesq_score,
-                            "loss": average_loss,
-                            "ae_loss": average_ae_loss,
-                            "neg_si_snr_loss": average_neg_si_snr_loss,
-                            "neg_stoi_loss": average_neg_stoi_loss,
-                            "mr_stft_loss": average_mr_stft_loss,
-                            "pesq_loss": average_pesq_loss,
-                        })
-
-                    total_pesq_score = 0.
-                    total_loss = 0.
-                    total_ae_loss = 0.
-                    total_neg_si_snr_loss = 0.
-                    total_neg_stoi_loss = 0.
-                    total_mr_stft_loss = 0.
-                    total_pesq_loss = 0.
-                    total_discriminator_g_loss = 0.
-                    total_discriminator_d_loss = 0.
-                    total_batches = 0.
-
-                    progress_bar_eval.close()
-                    progress_bar_train = tqdm(
-                        initial=progress_bar_train.n,
-                        postfix=progress_bar_train.postfix,
-                        desc=progress_bar_train.desc,
-                    )
-
-                    # save path
-                    save_dir = serialization_dir / "steps-{}".format(step_idx)
-                    save_dir.mkdir(parents=True, exist_ok=False)
-
-                    # save models
-                    model.save_pretrained(save_dir.as_posix())
-                    discriminator.save_pretrained(save_dir.as_posix())
-
-                    # save optim
-                    torch.save(optimizer.state_dict(), (save_dir / "optimizer.pth").as_posix())
-                    torch.save(discriminator_optimizer.state_dict(), (save_dir / "discriminator_optimizer.pth").as_posix())
-
-                    model_list.append(save_dir)
-                    if len(model_list) >= args.num_serialized_models_to_keep:
-                        model_to_delete: Path = model_list.pop(0)
-                        shutil.rmtree(model_to_delete.as_posix())
-
-                    # save metric
-                    if best_metric is None:
-                        best_epoch_idx = epoch_idx
-                        best_step_idx = step_idx
-                        best_metric = average_pesq_score
-                    elif average_pesq_score > best_metric:
-                        # greater is better.
-                        best_epoch_idx = epoch_idx
-                        best_step_idx = step_idx
-                        best_metric = average_pesq_score
-                    else:
-                        pass
-
-                    metrics = {
-                        "epoch_idx": epoch_idx,
-                        "best_epoch_idx": best_epoch_idx,
-                        "best_step_idx": best_step_idx,
-                        "pesq_score": average_pesq_score,
-                        "loss": average_loss,
-                        "ae_loss": average_ae_loss,
-                        "neg_si_snr_loss": average_neg_si_snr_loss,
-                        "neg_stoi_loss": average_neg_stoi_loss,
-                        "mr_stft_loss": average_mr_stft_loss,
-                        "pesq_loss": average_pesq_loss,
-                    }
-                    metrics_filename = save_dir / "metrics_epoch.json"
-                    with open(metrics_filename, "w", encoding="utf-8") as f:
-                        json.dump(metrics, f, indent=4, ensure_ascii=False)
-
-                    # save best
-                    best_dir = serialization_dir / "best"
-                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
-                        if best_dir.exists():
-                            shutil.rmtree(best_dir)
-                        shutil.copytree(save_dir, best_dir)
-
-                    # early stop
-                    early_stop_flag = False
-                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
-                        patience_count = 0
-                    else:
-                        patience_count += 1
-                        if patience_count >= args.patience:
-                            early_stop_flag = True
-
-                    # early stop
-                    if early_stop_flag:
-                        break
-
-    return
-
-
-if __name__ == "__main__":
-    main()
examples/conv_tasnet_gan/yaml/config.yaml DELETED
@@ -1,31 +0,0 @@
-model_name: "conv_tasnet_gan"
-
-sample_rate: 8000
-segment_size: 4
-
-win_size: 20
-freq_bins: 256
-bottleneck_channels: 128
-num_speakers: 1
-num_blocks: 2
-num_sub_blocks: 4
-sub_blocks_channels: 256
-sub_blocks_kernel_size: 3
-
-norm_type: "gLN"
-causal: false
-mask_nonlinear: "relu"
-
-min_snr_db: -10
-max_snr_db: 20
-
-lr: 0.005
-adam_b1: 0.8
-adam_b2: 0.99
-
-lr_scheduler: "CosineAnnealingLR"
-lr_scheduler_kwargs:
-  T_max: 250000
-  eta_min: 0.00005
-
-eval_steps: 25000
examples/conv_tasnet_gan/yaml/discriminator_config.yaml DELETED
@@ -1,10 +0,0 @@
-model_name: "conv_tasnet_gan"
-
-sample_rate: 8000
-segment_size: 16000
-n_fft: 512
-win_size: 200
-hop_size: 80
-
-discriminator_dim: 32
-discriminator_in_channel: 2
examples/dfnet/step_2_train_model.py CHANGED
@@ -315,6 +315,7 @@ def main():
             # evaluation
             step_idx += 1
             if step_idx % config.eval_steps == 0:
+                model.eval()
                 with torch.no_grad():
                     torch.cuda.empty_cache()
 
@@ -451,6 +452,7 @@ def main():
                    # early stop
                    if early_stop_flag:
                        break
+                model.train()
 
    return
 
examples/dfnet2/step_2_train_model.py CHANGED
@@ -318,6 +318,7 @@ def main():
             # evaluation
             step_idx += 1
             if step_idx % config.eval_steps == 0:
+                model.eval()
                 with torch.no_grad():
                     torch.cuda.empty_cache()
 
@@ -457,6 +458,7 @@ def main():
                    # early stop
                    if early_stop_flag:
                        break
+                model.train()
 
    return
 
examples/dtln/step_2_train_model.py CHANGED
@@ -301,6 +301,7 @@ def main():
             # evaluation
             step_idx += 1
             if step_idx % config.eval_steps == 0:
+                model.eval()
                 with torch.no_grad():
                     torch.cuda.empty_cache()
 
@@ -424,6 +425,7 @@ def main():
                    # early stop
                    if early_stop_flag:
                        break
+                model.train()
 
    return
 
examples/frcrn/step_2_train_model.py CHANGED
@@ -305,6 +305,7 @@ def main():
             # evaluation
             step_idx += 1
             if step_idx % config.eval_steps == 0:
+                model.eval()
                 with torch.no_grad():
                     torch.cuda.empty_cache()
 
@@ -428,6 +429,7 @@ def main():
                    # early stop
                    if early_stop_flag:
                        break
+                model.train()
 
    return
 
examples/lstm/step_2_train_model.py CHANGED
@@ -314,6 +314,7 @@ def main():
             # evaluation
             step_idx += 1
             if step_idx % config.eval_steps == 0:
+                model.eval()
                 with torch.no_grad():
                     torch.cuda.empty_cache()
 
@@ -435,6 +436,7 @@ def main():
                    # early stop
                    if early_stop_flag:
                        break
+                model.train()
    return
 
 
examples/rnnoise/step_2_train_model.py CHANGED
@@ -314,6 +314,7 @@ def main():
             # evaluation
             step_idx += 1
             if step_idx % config.eval_steps == 0:
+                model.eval()
                 with torch.no_grad():
                     torch.cuda.empty_cache()
 
@@ -435,6 +436,8 @@ def main():
                    # early stop
                    if early_stop_flag:
                        break
+                model.train()
+
    return
 
 
toolbox/torchaudio/models/dfnet2/modeling_dfnet2.py CHANGED
@@ -1047,6 +1047,9 @@ class DfNet2(nn.Module):
         feat_spec = feat_spec[..., :self.df_decoder.df_bins]
         # feat_spec shape: [b, 2, t, df_bins]
 
+        spec = spec.detach()
+        feat_erb = feat_erb.detach()
+        feat_spec = feat_spec.detach()
         return spec, feat_erb, feat_spec
 
     def forward(self,
toolbox/torchaudio/modules/utils/ema.py CHANGED
@@ -1,8 +1,95 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
+import math
+
+import numpy as np
 import torch.nn as nn
 
 
+def _calculate_norm_alpha(sample_rate: int, hop_size: int, tau: float):
+    """Exponential decay factor alpha for a given tau (decay window size [s])."""
+    dt = hop_size / sample_rate
+    result = math.exp(-dt / tau)
+    return result
+
+
+def get_norm_alpha(sample_rate: int, hop_size: int, norm_tau: float) -> float:
+    a_ = _calculate_norm_alpha(sample_rate=sample_rate, hop_size=hop_size, tau=norm_tau)
+
+    precision = 3
+    a = 1.0
+    while a >= 1.0:
+        a = round(a_, precision)
+        precision += 1
+
+    return a
+
+
+MEAN_NORM_INIT = [-60., -90.]
+
+
+def make_erb_norm_state(erb_bins: int, channels: int) -> np.ndarray:
+    state = np.linspace(MEAN_NORM_INIT[0], MEAN_NORM_INIT[1], erb_bins)
+    state = np.expand_dims(state, axis=0)
+    state = np.repeat(state, channels, axis=0)
+
+    # state shape: (audio_channels, erb_bins)
+    return state
+
+
+def erb_normalize(erb_feat: np.ndarray, alpha: float, state: np.ndarray = None):
+    erb_feat = np.copy(erb_feat)
+    batch_size, time_steps, erb_bins = erb_feat.shape
+
+    if state is None:
+        state = make_erb_norm_state(erb_bins, erb_feat.shape[0])
+        # state = np.linspace(MEAN_NORM_INIT[0], MEAN_NORM_INIT[1], erb_bins)
+        # state = np.expand_dims(state, axis=0)
+        # state = np.repeat(state, erb_feat.shape[0], axis=0)
+
+    for i in range(batch_size):
+        for j in range(time_steps):
+            for k in range(erb_bins):
+                x = erb_feat[i][j][k]
+                s = state[i][k]
+
+                state[i][k] = x * (1. - alpha) + s * alpha
+                erb_feat[i][j][k] -= state[i][k]
+                erb_feat[i][j][k] /= 40.
+
+    return erb_feat
+
+
+UNIT_NORM_INIT = [0.001, 0.0001]
+
+
+def make_spec_norm_state(df_bins: int, channels: int) -> np.ndarray:
+    state = np.linspace(UNIT_NORM_INIT[0], UNIT_NORM_INIT[1], df_bins)
+    state = np.expand_dims(state, axis=0)
+    state = np.repeat(state, channels, axis=0)
+
+    # state shape: (audio_channels, df_bins)
+    return state
+
+
+def spec_normalize(spec_feat: np.ndarray, alpha: float, state: np.ndarray = None):
+    spec_feat = np.copy(spec_feat)
+    batch_size, time_steps, df_bins = spec_feat.shape
+
+    if state is None:
+        state = make_spec_norm_state(df_bins, spec_feat.shape[0])
+
+    for i in range(batch_size):
+        for j in range(time_steps):
+            for k in range(df_bins):
+                x = spec_feat[i][j][k]
+                s = state[i][k]
+
+                state[i][k] = np.abs(x) * (1. - alpha) + s * alpha
+                spec_feat[i][j][k] /= np.sqrt(state[i][k])
+    return spec_feat
+
+
 class ExponentialMovingAverage(nn.Module):
     def __init__(self):
         super().__init__()
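The normalization helpers added above are plain NumPy reference implementations, so they can be exercised outside the model. A minimal usage sketch follows; the sample rate, hop size, decay window, and feature-bin counts are illustrative assumptions, not values taken from this commit:

    import numpy as np

    from toolbox.torchaudio.modules.utils.ema import (
        get_norm_alpha, erb_normalize, spec_normalize,
    )

    # Assumed parameters: 48 kHz audio, 480-sample hop, 1 s decay window.
    alpha = get_norm_alpha(sample_rate=48000, hop_size=480, norm_tau=1.0)

    # Random ERB features in dB: (batch, time_steps, erb_bins).
    erb_feat = np.random.uniform(-90.0, -60.0, size=(1, 100, 32))
    erb_norm = erb_normalize(erb_feat, alpha=alpha)

    # Random spectral magnitudes: (batch, time_steps, df_bins).
    spec_feat = np.random.uniform(0.0, 1.0, size=(1, 100, 96))
    spec_norm = spec_normalize(spec_feat, alpha=alpha)

    print(erb_norm.shape, spec_norm.shape)  # (1, 100, 32) (1, 100, 96)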