Spaces:
Running
Running
add ema
Browse files
examples/dtln/run.sh
CHANGED
@@ -6,7 +6,7 @@ sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name f
|
|
6 |
--noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
|
7 |
--speech_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/speech"
|
8 |
|
9 |
-
sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dtln-dns3 \
|
10 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
|
11 |
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
|
12 |
|
|
|
6 |
--noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
|
7 |
--speech_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/speech"
|
8 |
|
9 |
+
sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dtln-nx-dns3 \
|
10 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
|
11 |
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
|
12 |
|
examples/rnnoise/run.sh
CHANGED
@@ -8,7 +8,8 @@ sh run.sh --stage 3 --stop_stage 3 --system_version windows --file_folder_name f
|
|
8 |
|
9 |
sh run.sh --stage 1 --stop_stage 3 --system_version centos --file_folder_name file_dir \
|
10 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
11 |
-
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
|
|
12 |
|
13 |
|
14 |
END
|
|
|
8 |
|
9 |
sh run.sh --stage 1 --stop_stage 3 --system_version centos --file_folder_name file_dir \
|
10 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
11 |
+
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" \
|
12 |
+
--sparse
|
13 |
|
14 |
|
15 |
END
|
examples/rnnoise/step_2_train_model.py
CHANGED
@@ -48,6 +48,8 @@ def get_args():
|
|
48 |
|
49 |
parser.add_argument("--config_file", default="config.yaml", type=str)
|
50 |
|
|
|
|
|
51 |
args = parser.parse_args()
|
52 |
return args
|
53 |
|
@@ -289,6 +291,8 @@ def main():
|
|
289 |
loss.backward()
|
290 |
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
|
291 |
optimizer.step()
|
|
|
|
|
292 |
lr_scheduler.step()
|
293 |
|
294 |
total_pesq_score += pesq_score
|
|
|
48 |
|
49 |
parser.add_argument("--config_file", default="config.yaml", type=str)
|
50 |
|
51 |
+
parser.add_argument("--sparse", action="store_true")
|
52 |
+
|
53 |
args = parser.parse_args()
|
54 |
return args
|
55 |
|
|
|
291 |
loss.backward()
|
292 |
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
|
293 |
optimizer.step()
|
294 |
+
if args.sparse:
|
295 |
+
model.sparsify()
|
296 |
lr_scheduler.step()
|
297 |
|
298 |
total_pesq_score += pesq_score
|
toolbox/torchaudio/models/dfnet2/modeling_dfnet2.py
CHANGED
@@ -24,6 +24,7 @@ from toolbox.torchaudio.models.dfnet2.configuration_dfnet2 import DfNet2Config
|
|
24 |
from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT
|
25 |
from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget
|
26 |
from toolbox.torchaudio.modules.freq_bands.erb_bands import ErbBands
|
|
|
27 |
|
28 |
|
29 |
MODEL_FILE = "model.pt"
|
@@ -965,13 +966,6 @@ class DfNet2(nn.Module):
|
|
965 |
self.hop_size = config.hop_size
|
966 |
self.win_type = config.win_type
|
967 |
|
968 |
-
self.erb_bands = ErbBands(
|
969 |
-
sample_rate=config.sample_rate,
|
970 |
-
nfft=config.nfft,
|
971 |
-
erb_bins=config.erb_bins,
|
972 |
-
min_freq_bins_for_erb=config.min_freq_bins_for_erb,
|
973 |
-
)
|
974 |
-
|
975 |
self.stft = ConvSTFT(
|
976 |
nfft=config.nfft,
|
977 |
win_size=config.win_size,
|
@@ -988,6 +982,24 @@ class DfNet2(nn.Module):
|
|
988 |
requires_grad=False
|
989 |
)
|
990 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
991 |
self.encoder = Encoder(config)
|
992 |
self.erb_decoder = ErbDecoder(config)
|
993 |
|
@@ -1052,6 +1064,24 @@ class DfNet2(nn.Module):
|
|
1052 |
feat_spec = feat_spec.detach()
|
1053 |
return spec, feat_erb, feat_spec
|
1054 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1055 |
def forward(self,
|
1056 |
noisy: torch.Tensor,
|
1057 |
):
|
@@ -1067,6 +1097,7 @@ class DfNet2(nn.Module):
|
|
1067 |
noisy = self.signal_prepare(noisy)
|
1068 |
|
1069 |
spec, feat_erb, feat_spec = self.feature_prepare(noisy)
|
|
|
1070 |
|
1071 |
e0, e1, e2, e3, emb, c0, lsnr, _ = self.encoder.forward(feat_erb, feat_spec)
|
1072 |
|
@@ -1137,6 +1168,7 @@ class DfNet2(nn.Module):
|
|
1137 |
cache_dict3 = None
|
1138 |
cache_dict4 = None
|
1139 |
cache_dict5 = None
|
|
|
1140 |
|
1141 |
waveform_list = list()
|
1142 |
for i in range(int(t)):
|
@@ -1148,6 +1180,7 @@ class DfNet2(nn.Module):
|
|
1148 |
# spec shape: [b, 1, t, f, 2]
|
1149 |
# feat_erb shape: [b, 1, t, erb_bins]
|
1150 |
# feat_spec shape: [b, 2, t, df_bins]
|
|
|
1151 |
|
1152 |
e0, e1, e2, e3, emb, c0, lsnr, cache_dict0 = self.encoder.forward(feat_erb, feat_spec, cache_dict=cache_dict0)
|
1153 |
|
@@ -1174,10 +1207,6 @@ class DfNet2(nn.Module):
|
|
1174 |
spec_f, cache_dict3 = self.df_op.forward_online(spec_, df_coefs, cache_dict=cache_dict3)
|
1175 |
# spec_f shape: [b, 1, t, df_bins, 2], torch.float32
|
1176 |
|
1177 |
-
spec_e = torch.concat(tensors=[
|
1178 |
-
spec_f, spec_m[..., self.df_decoder.df_bins:, :]
|
1179 |
-
], dim=3)
|
1180 |
-
|
1181 |
spec_e, cache_dict4 = self.spec_e_m_combine_online(spec_f, spec_m, cache_dict=cache_dict4)
|
1182 |
|
1183 |
spec_e = torch.squeeze(spec_e, dim=1)
|
|
|
24 |
from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT
|
25 |
from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget
|
26 |
from toolbox.torchaudio.modules.freq_bands.erb_bands import ErbBands
|
27 |
+
from toolbox.torchaudio.modules.utils.ema import ErbEMA, SpecEMA
|
28 |
|
29 |
|
30 |
MODEL_FILE = "model.pt"
|
|
|
966 |
self.hop_size = config.hop_size
|
967 |
self.win_type = config.win_type
|
968 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
969 |
self.stft = ConvSTFT(
|
970 |
nfft=config.nfft,
|
971 |
win_size=config.win_size,
|
|
|
982 |
requires_grad=False
|
983 |
)
|
984 |
|
985 |
+
self.erb_bands = ErbBands(
|
986 |
+
sample_rate=config.sample_rate,
|
987 |
+
nfft=config.nfft,
|
988 |
+
erb_bins=config.erb_bins,
|
989 |
+
min_freq_bins_for_erb=config.min_freq_bins_for_erb,
|
990 |
+
)
|
991 |
+
|
992 |
+
self.erb_ema = ErbEMA(
|
993 |
+
sample_rate=config.sample_rate,
|
994 |
+
hop_size=config.hop_size,
|
995 |
+
erb_bins=config.erb_bins,
|
996 |
+
)
|
997 |
+
self.spec_ema = SpecEMA(
|
998 |
+
sample_rate=config.sample_rate,
|
999 |
+
hop_size=config.hop_size,
|
1000 |
+
df_bins=config.df_bins,
|
1001 |
+
)
|
1002 |
+
|
1003 |
self.encoder = Encoder(config)
|
1004 |
self.erb_decoder = ErbDecoder(config)
|
1005 |
|
|
|
1064 |
feat_spec = feat_spec.detach()
|
1065 |
return spec, feat_erb, feat_spec
|
1066 |
|
1067 |
+
def feature_norm(self, feat_erb, feat_spec, cache_dict: dict = None):
    """Apply the running EMA normalisation to both encoder input features.

    ``feat_erb`` is normalised by ``self.erb_ema.norm`` and ``feat_spec`` by
    ``self.spec_ema.norm``.  The running state of each normaliser is carried
    between calls through ``cache_dict`` so the method can be used chunk by
    chunk as well as on a whole utterance.

    :param feat_erb: ERB feature tensor handed to ``self.erb_ema.norm``.
    :param feat_spec: spectral feature tensor handed to ``self.spec_ema.norm``.
    :param cache_dict: optional mapping holding the previous states under the
        keys ``"cache0"`` (ERB) and ``"cache1"`` (spec).  ``None`` or a missing
        key means "no previous state", which lets each normaliser start from
        its own initial state.
    :return: ``(feat_erb, feat_spec, new_cache_dict)`` where both features are
        detached from the autograd graph and ``new_cache_dict`` holds the
        updated states under the same two keys.
    """
    if cache_dict is None:
        cache_dict = {}
    # .get() instead of indexing: a plain dict that lacks a key is treated as
    # "no state yet" rather than raising KeyError, and no defaultdict is needed.
    cache0 = cache_dict.get("cache0")
    cache1 = cache_dict.get("cache1")

    feat_erb, new_cache0 = self.erb_ema.norm(feat_erb, state=cache0)
    feat_spec, new_cache1 = self.spec_ema.norm(feat_spec, state=cache1)

    new_cache_dict = {
        "cache0": new_cache0,
        "cache1": new_cache1,
    }

    # The normalised features are inputs only; gradients are not propagated
    # through the normalisation.
    feat_erb = feat_erb.detach()
    feat_spec = feat_spec.detach()
    return feat_erb, feat_spec, new_cache_dict
|
1084 |
+
|
1085 |
def forward(self,
|
1086 |
noisy: torch.Tensor,
|
1087 |
):
|
|
|
1097 |
noisy = self.signal_prepare(noisy)
|
1098 |
|
1099 |
spec, feat_erb, feat_spec = self.feature_prepare(noisy)
|
1100 |
+
feat_erb, feat_spec, _ = self.feature_norm(feat_erb, feat_spec)
|
1101 |
|
1102 |
e0, e1, e2, e3, emb, c0, lsnr, _ = self.encoder.forward(feat_erb, feat_spec)
|
1103 |
|
|
|
1168 |
cache_dict3 = None
|
1169 |
cache_dict4 = None
|
1170 |
cache_dict5 = None
|
1171 |
+
cache_dict6 = None
|
1172 |
|
1173 |
waveform_list = list()
|
1174 |
for i in range(int(t)):
|
|
|
1180 |
# spec shape: [b, 1, t, f, 2]
|
1181 |
# feat_erb shape: [b, 1, t, erb_bins]
|
1182 |
# feat_spec shape: [b, 2, t, df_bins]
|
1183 |
+
feat_erb, feat_spec, cache_dict6 = self.feature_norm(feat_erb, feat_spec, cache_dict=cache_dict6)
|
1184 |
|
1185 |
e0, e1, e2, e3, emb, c0, lsnr, cache_dict0 = self.encoder.forward(feat_erb, feat_spec, cache_dict=cache_dict0)
|
1186 |
|
|
|
1207 |
spec_f, cache_dict3 = self.df_op.forward_online(spec_, df_coefs, cache_dict=cache_dict3)
|
1208 |
# spec_f shape: [b, 1, t, df_bins, 2], torch.float32
|
1209 |
|
|
|
|
|
|
|
|
|
1210 |
spec_e, cache_dict4 = self.spec_e_m_combine_online(spec_f, spec_m, cache_dict=cache_dict4)
|
1211 |
|
1212 |
spec_e = torch.squeeze(spec_e, dim=1)
|
toolbox/torchaudio/modules/utils/ema.py
CHANGED
@@ -3,26 +3,133 @@
|
|
3 |
import math
|
4 |
|
5 |
import numpy as np
|
|
|
6 |
import torch.nn as nn
|
7 |
|
8 |
|
9 |
-
|
10 |
-
"""Exponential decay factor alpha for a given tau (decay window size [s])."""
|
11 |
-
dt = hop_size / sample_rate
|
12 |
-
result = math.exp(-dt / tau)
|
13 |
-
return result
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
-
|
17 |
-
|
|
|
18 |
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
|
28 |
MEAN_NORM_INIT = [-60., -90.]
|
@@ -90,10 +197,5 @@ def spec_normalize(spec_feat: np.ndarray, alpha: float, state: np.ndarray = None
|
|
90 |
return spec_feat
|
91 |
|
92 |
|
93 |
-
class ExponentialMovingAverage(nn.Module):
|
94 |
-
def __init__(self):
|
95 |
-
super().__init__()
|
96 |
-
|
97 |
-
|
98 |
if __name__ == "__main__":
|
99 |
pass
|
|
|
3 |
import math
|
4 |
|
5 |
import numpy as np
|
6 |
+
import torch
|
7 |
import torch.nn as nn
|
8 |
|
9 |
|
10 |
+
class EMANumpy(object):
    """Helpers for deriving the exponential-moving-average smoothing factor."""

    @classmethod
    def _calculate_norm_alpha(cls, sample_rate: int, hop_size: int, tau: float):
        """Exponential decay factor alpha for a given tau (decay window size [s])."""
        dt = hop_size / sample_rate
        result = math.exp(-dt / tau)
        return result

    @classmethod
    def get_norm_alpha(cls, sample_rate: int, hop_size: int, norm_tau: float) -> float:
        """Return the decay factor rounded to the fewest decimals that keep it below 1.

        Rounding starts at 3 decimal places and adds one place at a time until
        the rounded value is strictly less than 1.0.

        :param sample_rate: sample rate in Hz.
        :param hop_size: frame hop in samples.
        :param norm_tau: decay window size in seconds.
        :return: the rounded smoothing factor.
        :raises ValueError: if the unrounded factor is already >= 1.0, because
            no amount of rounding precision can then bring it below 1.0.
        :raises ZeroDivisionError: if ``sample_rate`` or ``norm_tau`` is zero.
        """
        a_ = cls._calculate_norm_alpha(sample_rate=sample_rate, hop_size=hop_size, tau=norm_tau)

        # Without this guard the loop below never terminates (e.g. hop_size == 0
        # gives exactly 1.0, a negative tau gives a value above 1.0).
        if a_ >= 1.0:
            raise ValueError(
                "decay factor must be below 1.0, got {} for sample_rate={}, hop_size={}, norm_tau={}".format(
                    a_, sample_rate, hop_size, norm_tau
                )
            )

        precision = 3
        a = 1.0
        while a >= 1.0:
            a = round(a_, precision)
            precision += 1

        return a
|
30 |
+
|
31 |
+
|
32 |
+
class ErbEMA(nn.Module, EMANumpy):
    """Running-mean normalisation of ERB features along the time axis.

    For every frame the running mean is updated as
    ``state = frame * (1 - alpha) + state * alpha`` and the frame is replaced
    by ``(frame - state) / 40.0``.  ``alpha`` comes from the inherited
    ``get_norm_alpha(sample_rate, hop_size, norm_tau)``.
    """

    def __init__(self,
                 sample_rate: int = 8000,
                 hop_size: int = 80,
                 erb_bins: int = 32,
                 mean_norm_init_start: float = -60.,
                 mean_norm_init_end: float = -90.,
                 norm_tau: float = 1.,
                 ):
        """
        :param sample_rate: sample rate in Hz, used to derive ``alpha``.
        :param hop_size: frame hop in samples, used to derive ``alpha``.
        :param erb_bins: number of ERB bins, i.e. the length of the state.
        :param mean_norm_init_start: initial running mean of the first bin.
        :param mean_norm_init_end: initial running mean of the last bin.
        :param norm_tau: decay window size in seconds, used to derive ``alpha``.
        """
        super().__init__()
        self.sample_rate = sample_rate
        self.hop_size = hop_size
        self.erb_bins = erb_bins
        self.mean_norm_init_start = mean_norm_init_start
        self.mean_norm_init_end = mean_norm_init_end
        self.norm_tau = norm_tau

        self.alpha = self.get_norm_alpha(sample_rate, hop_size, norm_tau)

    def make_erb_norm_state(self) -> torch.Tensor:
        """Build the initial running mean, linearly spaced across the ERB bins.

        :return: tensor of shape [1, 1, erb_bins]; the two leading singleton
            axes broadcast against the batch and channel axes in :meth:`norm`.
        """
        state = torch.linspace(start=self.mean_norm_init_start, end=self.mean_norm_init_end,
                               steps=self.erb_bins)
        state = state.unsqueeze(0).unsqueeze(0)
        # state shape: [1, 1, erb_bins]
        return state

    def norm(self,
             feat_erb: torch.Tensor,
             state: torch.Tensor = None,
             ):
        """Normalise ``feat_erb`` frame by frame with the running mean.

        :param feat_erb: tensor of shape [b, c, t, f]; it is not modified, a
            clone is normalised and returned.
        :param state: running mean from a previous call, or ``None`` to start
            from :meth:`make_erb_norm_state`.
        :return: ``(normalised_feat_erb, state)`` where ``state`` is the running
            mean after the last frame (shape [b, c, f] once at least one frame
            has been processed).
        """
        feat_erb = feat_erb.clone()
        # feat_erb shape: [b, c, t, f]
        _, _, t, _ = feat_erb.shape

        if state is None:
            # The initial state is created on CPU; move it next to the features
            # so the arithmetic below also works on an accelerator.
            state = self.make_erb_norm_state().to(device=feat_erb.device)
        state = state.clone()

        for j in range(t):
            current = feat_erb[:, :, j, :]
            new_state = current * (1 - self.alpha) + state * self.alpha

            feat_erb[:, :, j, :] = (current - new_state) / 40.0
            state = new_state

        return feat_erb, state
|
79 |
+
|
80 |
+
|
81 |
+
class SpecEMA(nn.Module, EMANumpy):
    """Running unit normalisation of the complex spectral features along time.

    For every frame the per-bin energy (sum of squares over the channel axis)
    is smoothed as ``state = energy * (1 - alpha) + state * alpha`` and the
    frame is divided by ``sqrt(state)``.  ``alpha`` comes from the inherited
    ``get_norm_alpha(sample_rate, hop_size, norm_tau)``.

    https://github.com/grazder/DeepFilterNet/blob/torchDF_main/libDF/src/lib.rs
    """

    def __init__(self,
                 sample_rate: int = 8000,
                 hop_size: int = 80,
                 df_bins: int = 96,
                 unit_norm_init_start: float = 0.001,
                 unit_norm_init_end: float = 0.0001,
                 norm_tau: float = 1.,
                 ):
        """
        :param sample_rate: sample rate in Hz, used to derive ``alpha``.
        :param hop_size: frame hop in samples, used to derive ``alpha``.
        :param df_bins: number of frequency bins, i.e. the length of the state.
        :param unit_norm_init_start: initial state value of the first bin.
        :param unit_norm_init_end: initial state value of the last bin.
        :param norm_tau: decay window size in seconds, used to derive ``alpha``.
        """
        super().__init__()
        self.sample_rate = sample_rate
        self.hop_size = hop_size
        self.df_bins = df_bins
        self.unit_norm_init_start = unit_norm_init_start
        self.unit_norm_init_end = unit_norm_init_end
        self.norm_tau = norm_tau

        self.alpha = self.get_norm_alpha(sample_rate, hop_size, norm_tau)

    def make_spec_norm_state(self) -> torch.Tensor:
        """Build the initial state, linearly spaced across the frequency bins.

        :return: tensor of shape [1, 1, df_bins]; the two leading singleton
            axes broadcast against the batch and channel axes in :meth:`norm`.
        """
        state = torch.linspace(start=self.unit_norm_init_start, end=self.unit_norm_init_end,
                               steps=self.df_bins)
        state = state.unsqueeze(0).unsqueeze(0)
        # state shape: [1, 1, df_bins]
        return state

    def norm(self,
             feat_spec: torch.Tensor,
             state: torch.Tensor = None,
             ):
        """Normalise ``feat_spec`` frame by frame with the running energy.

        :param feat_spec: tensor of shape [b, 2, t, df_bins]; it is not
            modified, a clone is normalised and returned.
        :param state: running state from a previous call, or ``None`` to start
            from :meth:`make_spec_norm_state`.
        :return: ``(normalised_feat_spec, state)`` where ``state`` is the
            running state after the last frame (shape [b, 1, df_bins] once at
            least one frame has been processed).
        """
        feat_spec = feat_spec.clone()
        # feat_spec shape: [b, 2, t, df_bins]
        _, _, t, _ = feat_spec.shape

        if state is None:
            # The initial state is created on CPU; move it next to the features
            # so the arithmetic below also works on an accelerator.
            state = self.make_spec_norm_state().to(device=feat_spec.device)
        state = state.clone()

        for j in range(t):
            current = feat_spec[:, :, j, :]
            # NOTE(review): this smooths the squared magnitude; the linked libDF
            # reference appears to smooth the magnitude itself - confirm which
            # one is intended before changing it.
            current_power = torch.sum(torch.square(current), dim=1, keepdim=True)
            # current_power shape: [b, 1, df_bins]
            new_state = current_power * (1 - self.alpha) + state * self.alpha

            # Floor the divisor at the smallest positive normal float so a state
            # that has reached zero yields 0 instead of 0/0 = NaN.
            scale = torch.clamp(torch.sqrt(new_state), min=torch.finfo(new_state.dtype).tiny)
            feat_spec[:, :, j, :] = current / scale
            state = new_state

        return feat_spec, state
|
133 |
|
134 |
|
135 |
MEAN_NORM_INIT = [-60., -90.]
|
|
|
197 |
return spec_feat
|
198 |
|
199 |
|
|
|
|
|
|
|
|
|
|
|
200 |
if __name__ == "__main__":
|
201 |
pass
|