HoneyTian committed on
Commit
169d6d6
·
1 Parent(s): 2171fed
examples/dtln/run.sh CHANGED
@@ -6,7 +6,7 @@ sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name f
6
  --noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
7
  --speech_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/speech"
8
 
9
- sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dtln-dns3 \
10
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
11
  --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
12
 
 
6
  --noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
7
  --speech_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/speech"
8
 
9
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dtln-nx-dns3 \
10
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
11
  --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
12
 
examples/rnnoise/run.sh CHANGED
@@ -8,7 +8,8 @@ sh run.sh --stage 3 --stop_stage 3 --system_version windows --file_folder_name f
8
 
9
  sh run.sh --stage 1 --stop_stage 3 --system_version centos --file_folder_name file_dir \
10
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
11
- --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
 
12
 
13
 
14
  END
 
8
 
9
  sh run.sh --stage 1 --stop_stage 3 --system_version centos --file_folder_name file_dir \
10
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
11
+ --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" \
12
+ --sparse
13
 
14
 
15
  END
examples/rnnoise/step_2_train_model.py CHANGED
@@ -48,6 +48,8 @@ def get_args():
48
 
49
  parser.add_argument("--config_file", default="config.yaml", type=str)
50
 
 
 
51
  args = parser.parse_args()
52
  return args
53
 
@@ -289,6 +291,8 @@ def main():
289
  loss.backward()
290
  torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
291
  optimizer.step()
 
 
292
  lr_scheduler.step()
293
 
294
  total_pesq_score += pesq_score
 
48
 
49
  parser.add_argument("--config_file", default="config.yaml", type=str)
50
 
51
+ parser.add_argument("--sparse", action="store_true")
52
+
53
  args = parser.parse_args()
54
  return args
55
 
 
291
  loss.backward()
292
  torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
293
  optimizer.step()
294
+ if args.sparse:
295
+ model.sparsify()
296
  lr_scheduler.step()
297
 
298
  total_pesq_score += pesq_score
toolbox/torchaudio/models/dfnet2/modeling_dfnet2.py CHANGED
@@ -24,6 +24,7 @@ from toolbox.torchaudio.models.dfnet2.configuration_dfnet2 import DfNet2Config
24
  from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT
25
  from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget
26
  from toolbox.torchaudio.modules.freq_bands.erb_bands import ErbBands
 
27
 
28
 
29
  MODEL_FILE = "model.pt"
@@ -965,13 +966,6 @@ class DfNet2(nn.Module):
965
  self.hop_size = config.hop_size
966
  self.win_type = config.win_type
967
 
968
- self.erb_bands = ErbBands(
969
- sample_rate=config.sample_rate,
970
- nfft=config.nfft,
971
- erb_bins=config.erb_bins,
972
- min_freq_bins_for_erb=config.min_freq_bins_for_erb,
973
- )
974
-
975
  self.stft = ConvSTFT(
976
  nfft=config.nfft,
977
  win_size=config.win_size,
@@ -988,6 +982,24 @@ class DfNet2(nn.Module):
988
  requires_grad=False
989
  )
990
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
991
  self.encoder = Encoder(config)
992
  self.erb_decoder = ErbDecoder(config)
993
 
@@ -1052,6 +1064,24 @@ class DfNet2(nn.Module):
1052
  feat_spec = feat_spec.detach()
1053
  return spec, feat_erb, feat_spec
1054
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1055
  def forward(self,
1056
  noisy: torch.Tensor,
1057
  ):
@@ -1067,6 +1097,7 @@ class DfNet2(nn.Module):
1067
  noisy = self.signal_prepare(noisy)
1068
 
1069
  spec, feat_erb, feat_spec = self.feature_prepare(noisy)
 
1070
 
1071
  e0, e1, e2, e3, emb, c0, lsnr, _ = self.encoder.forward(feat_erb, feat_spec)
1072
 
@@ -1137,6 +1168,7 @@ class DfNet2(nn.Module):
1137
  cache_dict3 = None
1138
  cache_dict4 = None
1139
  cache_dict5 = None
 
1140
 
1141
  waveform_list = list()
1142
  for i in range(int(t)):
@@ -1148,6 +1180,7 @@ class DfNet2(nn.Module):
1148
  # spec shape: [b, 1, t, f, 2]
1149
  # feat_erb shape: [b, 1, t, erb_bins]
1150
  # feat_spec shape: [b, 2, t, df_bins]
 
1151
 
1152
  e0, e1, e2, e3, emb, c0, lsnr, cache_dict0 = self.encoder.forward(feat_erb, feat_spec, cache_dict=cache_dict0)
1153
 
@@ -1174,10 +1207,6 @@ class DfNet2(nn.Module):
1174
  spec_f, cache_dict3 = self.df_op.forward_online(spec_, df_coefs, cache_dict=cache_dict3)
1175
  # spec_f shape: [b, 1, t, df_bins, 2], torch.float32
1176
 
1177
- spec_e = torch.concat(tensors=[
1178
- spec_f, spec_m[..., self.df_decoder.df_bins:, :]
1179
- ], dim=3)
1180
-
1181
  spec_e, cache_dict4 = self.spec_e_m_combine_online(spec_f, spec_m, cache_dict=cache_dict4)
1182
 
1183
  spec_e = torch.squeeze(spec_e, dim=1)
 
24
  from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT
25
  from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget
26
  from toolbox.torchaudio.modules.freq_bands.erb_bands import ErbBands
27
+ from toolbox.torchaudio.modules.utils.ema import ErbEMA, SpecEMA
28
 
29
 
30
  MODEL_FILE = "model.pt"
 
966
  self.hop_size = config.hop_size
967
  self.win_type = config.win_type
968
 
 
 
 
 
 
 
 
969
  self.stft = ConvSTFT(
970
  nfft=config.nfft,
971
  win_size=config.win_size,
 
982
  requires_grad=False
983
  )
984
 
985
+ self.erb_bands = ErbBands(
986
+ sample_rate=config.sample_rate,
987
+ nfft=config.nfft,
988
+ erb_bins=config.erb_bins,
989
+ min_freq_bins_for_erb=config.min_freq_bins_for_erb,
990
+ )
991
+
992
+ self.erb_ema = ErbEMA(
993
+ sample_rate=config.sample_rate,
994
+ hop_size=config.hop_size,
995
+ erb_bins=config.erb_bins,
996
+ )
997
+ self.spec_ema = SpecEMA(
998
+ sample_rate=config.sample_rate,
999
+ hop_size=config.hop_size,
1000
+ df_bins=config.df_bins,
1001
+ )
1002
+
1003
  self.encoder = Encoder(config)
1004
  self.erb_decoder = ErbDecoder(config)
1005
 
 
1064
  feat_spec = feat_spec.detach()
1065
  return spec, feat_erb, feat_spec
1066
 
1067
def feature_norm(self, feat_erb, feat_spec, cache_dict: dict = None):
    """
    Normalize the ERB and spectral features with their running-average modules.

    :param feat_erb: ERB feature tensor, handed to ``self.erb_ema.norm``.
    :param feat_spec: spectral feature tensor, handed to ``self.spec_ema.norm``.
    :param cache_dict: optional mapping with the keys ``"cache0"`` (ERB state)
        and ``"cache1"`` (spec state) from a previous call. When omitted, both
        states are ``None``.
    :return: ``(feat_erb, feat_spec, new_cache_dict)``; both features are
        detached from the autograd graph, and ``new_cache_dict`` carries the
        updated states under the same two keys.
    """
    states = defaultdict(lambda: None) if cache_dict is None else cache_dict

    # Read both states up front, before either normalizer is invoked.
    erb_state = states["cache0"]
    spec_state = states["cache1"]

    normed_erb, next_erb_state = self.erb_ema.norm(feat_erb, state=erb_state)
    normed_spec, next_spec_state = self.spec_ema.norm(feat_spec, state=spec_state)

    # The normalized features are inputs only; no gradient flows through them.
    return (
        normed_erb.detach(),
        normed_spec.detach(),
        {"cache0": next_erb_state, "cache1": next_spec_state},
    )
1084
+
1085
  def forward(self,
1086
  noisy: torch.Tensor,
1087
  ):
 
1097
  noisy = self.signal_prepare(noisy)
1098
 
1099
  spec, feat_erb, feat_spec = self.feature_prepare(noisy)
1100
+ feat_erb, feat_spec, _ = self.feature_norm(feat_erb, feat_spec)
1101
 
1102
  e0, e1, e2, e3, emb, c0, lsnr, _ = self.encoder.forward(feat_erb, feat_spec)
1103
 
 
1168
  cache_dict3 = None
1169
  cache_dict4 = None
1170
  cache_dict5 = None
1171
+ cache_dict6 = None
1172
 
1173
  waveform_list = list()
1174
  for i in range(int(t)):
 
1180
  # spec shape: [b, 1, t, f, 2]
1181
  # feat_erb shape: [b, 1, t, erb_bins]
1182
  # feat_spec shape: [b, 2, t, df_bins]
1183
+ feat_erb, feat_spec, cache_dict6 = self.feature_norm(feat_erb, feat_spec, cache_dict=cache_dict6)
1184
 
1185
  e0, e1, e2, e3, emb, c0, lsnr, cache_dict0 = self.encoder.forward(feat_erb, feat_spec, cache_dict=cache_dict0)
1186
 
 
1207
  spec_f, cache_dict3 = self.df_op.forward_online(spec_, df_coefs, cache_dict=cache_dict3)
1208
  # spec_f shape: [b, 1, t, df_bins, 2], torch.float32
1209
 
 
 
 
 
1210
  spec_e, cache_dict4 = self.spec_e_m_combine_online(spec_f, spec_m, cache_dict=cache_dict4)
1211
 
1212
  spec_e = torch.squeeze(spec_e, dim=1)
toolbox/torchaudio/modules/utils/ema.py CHANGED
@@ -3,26 +3,133 @@
3
  import math
4
 
5
  import numpy as np
 
6
  import torch.nn as nn
7
 
8
 
9
- def _calculate_norm_alpha(sample_rate: int, hop_size: int, tau: float):
10
- """Exponential decay factor alpha for a given tau (decay window size [s])."""
11
- dt = hop_size / sample_rate
12
- result = math.exp(-dt / tau)
13
- return result
14
 
 
 
 
 
 
 
15
 
16
- def get_norm_alpha(sample_rate: int, hop_size: int, norm_tau: float) -> float:
17
- a_ = _calculate_norm_alpha(sample_rate=sample_rate, hop_size=hop_size, tau=norm_tau)
 
18
 
19
- precision = 3
20
- a = 1.0
21
- while a >= 1.0:
22
- a = round(a_, precision)
23
- precision += 1
24
 
25
- return a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
 
28
  MEAN_NORM_INIT = [-60., -90.]
@@ -90,10 +197,5 @@ def spec_normalize(spec_feat: np.ndarray, alpha: float, state: np.ndarray = None
90
  return spec_feat
91
 
92
 
93
- class ExponentialMovingAverage(nn.Module):
94
- def __init__(self):
95
- super().__init__()
96
-
97
-
98
  if __name__ == "__main__":
99
  pass
 
3
  import math
4
 
5
  import numpy as np
6
+ import torch
7
  import torch.nn as nn
8
 
9
 
10
class EMANumpy(object):
    """Class-level helpers that derive the exponential-moving-average decay factor."""

    @classmethod
    def _calculate_norm_alpha(cls, sample_rate: int, hop_size: int, tau: float) -> float:
        """Exponential decay factor alpha for a given tau (decay window size [s])."""
        dt = hop_size / sample_rate
        result = math.exp(-dt / tau)
        return result

    @classmethod
    def get_norm_alpha(cls, sample_rate: int, hop_size: int, norm_tau: float) -> float:
        """
        Return the decay factor rounded to the fewest decimals (at least 3)
        that still keep it strictly below 1.0.

        :param sample_rate: audio sample rate.
        :param hop_size: frame hop, in samples.
        :param norm_tau: decay window size, in seconds.
        :return: the rounded decay factor.
        :raises ValueError: if the un-rounded factor is already >= 1.0
            (for example ``hop_size == 0`` or a negative ``norm_tau``).
        """
        a_ = cls._calculate_norm_alpha(sample_rate=sample_rate, hop_size=hop_size, tau=norm_tau)

        # No rounding precision can bring a factor >= 1.0 below 1.0, so the
        # search below would never terminate; reject such input explicitly.
        if a_ >= 1.0:
            raise ValueError(
                "decay factor must be < 1.0; got {} for sample_rate={}, hop_size={}, norm_tau={}".format(
                    a_, sample_rate, hop_size, norm_tau
                )
            )

        precision = 3
        a = round(a_, precision)
        # Increase the precision until rounding no longer snaps the value up to 1.0.
        while a >= 1.0:
            precision += 1
            a = round(a_, precision)

        return a
30
+
31
+
32
class ErbEMA(nn.Module, EMANumpy):
    """
    Running-mean normalization of ERB features.

    For every frame the running mean is updated with decay factor ``alpha``;
    the frame then has the updated mean subtracted and is divided by 40.
    """
    def __init__(self,
                 sample_rate: int = 8000,
                 hop_size: int = 80,
                 erb_bins: int = 32,
                 mean_norm_init_start: float = -60.,
                 mean_norm_init_end: float = -90.,
                 norm_tau: float = 1.,
                 ):
        super().__init__()
        self.sample_rate = sample_rate
        self.hop_size = hop_size
        self.erb_bins = erb_bins
        self.mean_norm_init_start = mean_norm_init_start
        self.mean_norm_init_end = mean_norm_init_end
        self.norm_tau = norm_tau

        # Decay factor of the running mean, derived from hop size, sample rate and tau.
        self.alpha = self.get_norm_alpha(sample_rate, hop_size, norm_tau)

    def make_erb_norm_state(self) -> torch.Tensor:
        """
        Build the initial running mean: a linear ramp over the ERB bins from
        ``mean_norm_init_start`` to ``mean_norm_init_end``.

        :return: tensor of shape [1, 1, erb_bins].
        """
        state = torch.linspace(start=self.mean_norm_init_start, end=self.mean_norm_init_end,
                               steps=self.erb_bins)
        state = state.unsqueeze(0).unsqueeze(0)
        # state shape: [b, c, erb_bins]
        # state shape: [1, 1, erb_bins]
        return state

    def norm(self,
             feat_erb: torch.Tensor,
             state: torch.Tensor = None,
             ):
        """
        Normalize ``feat_erb`` frame by frame along the time axis.

        :param feat_erb: 4-D tensor, shape [b, c, t, f]. It is not modified;
            a copy is normalized and returned.
        :param state: running mean from a previous call, or ``None`` to start
            from :meth:`make_erb_norm_state`.
        :return: ``(normalized_features, final_state)``.
        """
        feat_erb = feat_erb.clone()
        _, _, t, _ = feat_erb.shape

        # erb_feat shape: [b, c, t, f]
        if state is None:
            state = self.make_erb_norm_state()
        # The default state is created on the CPU; move it next to the features
        # so the update below does not fail with a device mismatch.
        state = state.to(feat_erb.device).clone()

        for j in range(t):
            current = feat_erb[:, :, j, :]
            new_state = current * (1 - self.alpha) + state * self.alpha

            feat_erb[:, :, j, :] = (current - new_state) / 40.0
            state = new_state

        return feat_erb, state
79
+
80
+
81
class SpecEMA(nn.Module, EMANumpy):
    """
    Running unit-norm normalization of spectral features.

    https://github.com/grazder/DeepFilterNet/blob/torchDF_main/libDF/src/lib.rs
    """
    def __init__(self,
                 sample_rate: int = 8000,
                 hop_size: int = 80,
                 df_bins: int = 96,
                 unit_norm_init_start: float = 0.001,
                 unit_norm_init_end: float = 0.0001,
                 norm_tau: float = 1.,
                 ):
        super().__init__()
        self.sample_rate = sample_rate
        self.hop_size = hop_size
        self.df_bins = df_bins
        self.unit_norm_init_start = unit_norm_init_start
        self.unit_norm_init_end = unit_norm_init_end
        self.norm_tau = norm_tau

        # Decay factor of the running average, derived from hop size, sample rate and tau.
        self.alpha = self.get_norm_alpha(sample_rate, hop_size, norm_tau)

    def make_spec_norm_state(self) -> torch.Tensor:
        """
        Build the initial running state: a linear ramp over the DF bins from
        ``unit_norm_init_start`` to ``unit_norm_init_end``.

        :return: tensor of shape [1, 1, df_bins].
        """
        state = torch.linspace(start=self.unit_norm_init_start, end=self.unit_norm_init_end,
                               steps=self.df_bins)
        state = state.unsqueeze(0).unsqueeze(0)
        # state shape: [b, c, df_bins]
        # state shape: [1, 1, df_bins]
        return state

    def norm(self,
             feat_spec: torch.Tensor,
             state: torch.Tensor = None,
             ):
        """
        Normalize ``feat_spec`` frame by frame along the time axis.

        Each frame is divided by the square root of a running average of the
        per-bin sum of squares over the channel axis.

        :param feat_spec: 4-D tensor, shape [b, 2, t, df_bins]. It is not
            modified; a copy is normalized and returned.
        :param state: running average from a previous call, or ``None`` to
            start from :meth:`make_spec_norm_state`.
        :return: ``(normalized_features, final_state)``.
        """
        feat_spec = feat_spec.clone()
        _, _, t, _ = feat_spec.shape

        # feat_spec shape: [b, 2, t, df_bins]
        if state is None:
            state = self.make_spec_norm_state()
        # The default state is created on the CPU; move it next to the features
        # so the update below does not fail with a device mismatch.
        state = state.to(feat_spec.device).clone()

        for j in range(t):
            current = feat_spec[:, :, j, :]
            # Sum of squares over the channel axis
            # (presumably real/imag parts, i.e. squared magnitude; confirm against caller).
            current_abs = torch.sum(torch.square(current), dim=1, keepdim=True)
            # current_abs shape: [b, 1, df_bins]
            new_state = current_abs * (1 - self.alpha) + state * self.alpha

            feat_spec[:, :, j, :] = current / torch.sqrt(new_state)
            state = new_state

        return feat_spec, state
133
 
134
 
135
  MEAN_NORM_INIT = [-60., -90.]
 
197
  return spec_feat
198
 
199
 
 
 
 
 
 
200
  if __name__ == "__main__":
201
  pass