HoneyTian committed
Commit f16472f · 1 Parent(s): 9b0f144
examples/spectrum_unet_irm_aishell/run.sh CHANGED
@@ -12,7 +12,7 @@ sh run.sh --stage 1 --stop_stage 3 --system_version centos --file_folder_name fi
   --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
   --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
 
-sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir \
+sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name file_dir \
   --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
   --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
 
examples/spectrum_unet_irm_aishell/step_2_train_model.py CHANGED
@@ -295,10 +295,13 @@ def main():
             if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
                 raise AssertionError("nan or inf in lsnr_prediction")
             irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
+            lsnr_prediction = (lsnr_prediction - config.lsnr_min) / (config.lsnr_max - config.lsnr_min)
+            if torch.max(lsnr_prediction) > 1 or torch.min(lsnr_prediction) < 0:
+                raise AssertionError(f"expected lsnr_prediction between 0 and 1.")
             snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
             if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
                 raise AssertionError("nan or inf in snr_loss")
-            loss = irm_loss + 0.1 * snr_loss
+            loss = irm_loss + 1.0 * snr_loss
             # loss = irm_loss
 
             total_loss += loss.item()
@@ -336,8 +339,11 @@ def main():
             if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
                 raise AssertionError("nan or inf in lsnr_prediction")
             irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
+            lsnr_prediction = (lsnr_prediction - config.lsnr_min) / (config.lsnr_max - config.lsnr_min)
+            if torch.max(lsnr_prediction) > 1 or torch.min(lsnr_prediction) < 0:
+                raise AssertionError(f"expected lsnr_prediction between 0 and 1.")
             snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
+            loss = irm_loss + 1.0 * snr_loss
             # loss = irm_loss
 
             total_loss += loss.item()
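
Note on the training change above: lsnr_prediction is min-max normalized into [0, 1] before the SNR MSE term, and the SNR loss weight is raised from 0.1 to 1.0, so the two loss terms sit on comparable scales. Below is a minimal sketch of that normalization; the +/-40 dB bounds are placeholders standing in for config.lsnr_min / config.lsnr_max, and it assumes snr_db_target is scaled to the same [0, 1] range elsewhere in the script.

    import torch

    # Hypothetical bounds; the real values come from config.lsnr_min / config.lsnr_max.
    lsnr_min, lsnr_max = -40.0, 40.0

    def normalize_lsnr(lsnr_db: torch.Tensor) -> torch.Tensor:
        """Min-max normalize a local SNR estimate in dB into [0, 1]."""
        lsnr = (lsnr_db - lsnr_min) / (lsnr_max - lsnr_min)
        if torch.max(lsnr) > 1 or torch.min(lsnr) < 0:
            raise AssertionError("expected normalized lsnr between 0 and 1")
        return lsnr

    # A prediction that stays inside the configured dB range normalizes cleanly.
    print(normalize_lsnr(torch.tensor([-40.0, 0.0, 40.0])))  # tensor([0.0000, 0.5000, 1.0000])

The in-loop assertion only holds if the raw prediction is already bounded to [lsnr_min, lsnr_max], e.g. by the model's LSNR head.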
examples/spectrum_unet_irm_aishell/yaml/config.yaml CHANGED
@@ -33,3 +33,6 @@ decoder_emb_num_layers: 3
 decoder_emb_skip_op: "none"
 decoder_emb_linear_groups: 16
 decoder_emb_hidden_size: 256
+
+# runtime
+use_post_filter: true
examples/test.py DELETED
@@ -1,18 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-import torch
-
-speech_spec = torch.tensor([0], dtype=torch.float32)
-noise_spec = torch.tensor([0], dtype=torch.float32)
-epsilon = 1e-8
-
-
-result = torch.log10(
-    speech_spec / (noise_spec + epsilon) + epsilon
-)
-
-print(result)
-
-
-if __name__ == '__main__':
-    pass
requirements-python-3-9-9.txt CHANGED
@@ -8,3 +8,4 @@ openpyxl==3.1.5
 torch==2.5.1
 torchaudio==2.5.1
 overrides==7.7.0
+torch-pesq==0.1.2
requirements.txt CHANGED
@@ -8,3 +8,4 @@ openpyxl==3.1.5
 torch==2.5.1
 torchaudio==2.5.1
 overrides==7.7.0
+torch-pesq
toolbox/torch/training/__init__.py ADDED
@@ -0,0 +1,6 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+
+if __name__ == '__main__':
+    pass
toolbox/torch/training/metrics/__init__.py ADDED
@@ -0,0 +1,6 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+
+if __name__ == '__main__':
+    pass
toolbox/torch/training/metrics/pesq.py ADDED
@@ -0,0 +1,108 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+from typing import Optional
+
+import torch
+from torch_pesq import PesqLoss
+
+
+class Pesq(object):
+    def __init__(self):
+        pass
+
+
+class CategoricalAccuracy(object):
+    def __init__(self, top_k: int = 1, tie_break: bool = False) -> None:
+        if top_k > 1 and tie_break:
+            raise AssertionError("Tie break in Categorical Accuracy "
+                                 "can be done only for maximum (top_k = 1)")
+        if top_k <= 0:
+            raise AssertionError("top_k passed to Categorical Accuracy must be > 0")
+        self._top_k = top_k
+        self._tie_break = tie_break
+        self.correct_count = 0.
+        self.total_count = 0.
+
+    def __call__(self,
+                 predictions: torch.Tensor,
+                 gold_labels: torch.Tensor,
+                 mask: Optional[torch.Tensor] = None):
+
+        # predictions, gold_labels, mask = self.unwrap_to_tensors(predictions, gold_labels, mask)
+
+        # Some sanity checks.
+        num_classes = predictions.size(-1)
+        if gold_labels.dim() != predictions.dim() - 1:
+            raise AssertionError("gold_labels must have dimension == predictions.size() - 1 but "
+                                 "found tensor of shape: {}".format(predictions.size()))
+        if (gold_labels >= num_classes).any():
+            raise AssertionError("A gold label passed to Categorical Accuracy contains an id >= {}, "
+                                 "the number of classes.".format(num_classes))
+
+        predictions = predictions.view((-1, num_classes))
+        gold_labels = gold_labels.view(-1).long()
+        if not self._tie_break:
+            # Top K indexes of the predictions (or fewer, if there aren't K of them).
+            # Special case topk == 1, because it's common and .max() is much faster than .topk().
+            if self._top_k == 1:
+                top_k = predictions.max(-1)[1].unsqueeze(-1)
+            else:
+                top_k = predictions.topk(min(self._top_k, predictions.shape[-1]), -1)[1]
+
+            # This is of shape (batch_size, ..., top_k).
+            correct = top_k.eq(gold_labels.unsqueeze(-1)).float()
+        else:
+            # prediction is correct if gold label falls on any of the max scores. distribute score by tie_counts
+            max_predictions = predictions.max(-1)[0]
+            max_predictions_mask = predictions.eq(max_predictions.unsqueeze(-1))
+            # max_predictions_mask is (rows X num_classes) and gold_labels is (batch_size)
+            # ith entry in gold_labels points to index (0-num_classes) for ith row in max_predictions
+            # For each row check if index pointed by gold_label is was 1 or not (among max scored classes)
+            correct = max_predictions_mask[torch.arange(gold_labels.numel()).long(), gold_labels].float()
+            tie_counts = max_predictions_mask.sum(-1)
+            correct /= tie_counts.float()
+            correct.unsqueeze_(-1)
+
+        if mask is not None:
+            correct *= mask.view(-1, 1).float()
+            self.total_count += mask.sum()
+        else:
+            self.total_count += gold_labels.numel()
+        self.correct_count += correct.sum()
+
+    def get_metric(self, reset: bool = False):
+        """
+        Returns
+        -------
+        The accumulated accuracy.
+        """
+        if self.total_count > 1e-12:
+            accuracy = float(self.correct_count) / float(self.total_count)
+        else:
+            accuracy = 0.0
+        if reset:
+            self.reset()
+        return {'accuracy': accuracy}
+
+    def reset(self):
+        self.correct_count = 0.0
+        self.total_count = 0.0
+
+
+def main():
+    pesq = PesqLoss(0.5,
+                    sample_rate=8000,
+                    )
+
+    reference = torch.randn(1, 44100)
+    degraded = torch.randn(1, 44100)
+
+    mos = pesq.mos(reference, degraded)
+    loss = pesq(reference, degraded)
+
+    print(mos, loss)
+    return
+
+
+if __name__ == '__main__':
+    main()
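
In this new file, Pesq is still an empty stub next to the copied CategoricalAccuracy class. A hypothetical sketch of what an accumulating PESQ metric wrapping torch_pesq.PesqLoss could look like in the same style follows; only the PesqLoss(0.5, sample_rate=...) constructor and the PesqLoss.mos() call are taken from the committed main(), while the class name, tensor shapes, and averaging scheme are assumptions.

    import torch
    from torch_pesq import PesqLoss


    class PesqMetric(object):
        """Hypothetical running-average MOS metric; not part of this commit."""
        def __init__(self, sample_rate: int = 8000):
            self._pesq = PesqLoss(0.5, sample_rate=sample_rate)
            self.total_mos = 0.0
            self.total_count = 0

        def __call__(self, reference: torch.Tensor, degraded: torch.Tensor):
            # reference / degraded shape: [batch_size, num_samples]
            mos = self._pesq.mos(reference, degraded)
            self.total_mos += float(mos.sum())
            self.total_count += mos.numel()

        def get_metric(self, reset: bool = False):
            mos = self.total_mos / self.total_count if self.total_count else 0.0
            if reset:
                self.reset()
            return {"mos": mos}

        def reset(self):
            self.total_mos = 0.0
            self.total_count = 0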
toolbox/torchaudio/models/spectrum_unet_irm/configuration_specturm_unet_irm.py CHANGED
@@ -33,6 +33,7 @@ class SpectrumUnetIRMConfig(PretrainedConfig):
                  decoder_emb_linear_groups: int = 16,
                  decoder_emb_hidden_size: int = 256,
 
+                 use_post_filter: bool = False,
                  **kwargs
                  ):
         super(SpectrumUnetIRMConfig, self).__init__(**kwargs)
@@ -67,6 +68,9 @@ class SpectrumUnetIRMConfig(PretrainedConfig):
         self.decoder_emb_linear_groups = decoder_emb_linear_groups
         self.decoder_emb_hidden_size = decoder_emb_hidden_size
 
+        # runtime
+        self.use_post_filter = use_post_filter
+
 
 if __name__ == "__main__":
     pass
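
The new use_post_filter flag defaults to False in the config class and is switched on through the "use_post_filter: true" entries in the YAML configs of this commit. A minimal sketch, assuming the keyword arguments mirror the YAML keys and the remaining constructor arguments keep their defaults:

    from toolbox.torchaudio.models.spectrum_unet_irm.configuration_specturm_unet_irm import SpectrumUnetIRMConfig

    config = SpectrumUnetIRMConfig(use_post_filter=True)
    print(config.use_post_filter)  # True; the model consults this flag only outside training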
toolbox/torchaudio/models/spectrum_unet_irm/modeling_spectrum_unet_irm.py CHANGED
@@ -570,6 +570,9 @@ class SpectrumUnetIRM(nn.Module):
         mask = torch.transpose(mask, dim0=2, dim1=1)
         lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
 
+        if not self.training and self.config.use_post_filter:
+            mask = self.post_filter(mask)
+
         # mask shape: [batch_size, freq_dim, time_steps]
         # lsnr shape: [batch_size, 1, time_steps]
         return mask, lsnr
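
The hunk above applies self.post_filter(mask) only at inference (not self.training) and only when the config enables it; the post_filter method itself is not part of this hunk. As an illustration only (not necessarily this project's implementation), a mask post-filter in the spirit of the one proposed by Valin et al. (2020) and used in DeepFilterNet attenuates low-gain bins, where mostly residual noise remains, while leaving gains near 1 essentially untouched:

    import math
    import torch

    def post_filter_sketch(mask: torch.Tensor, beta: float = 0.02, eps: float = 1e-12) -> torch.Tensor:
        # For small mask values the output behaves roughly like mask**3 (extra attenuation);
        # for mask -> 1 it approaches the identity, so speech-dominated bins are preserved.
        mask_sin = mask * torch.sin(0.5 * math.pi * mask)
        return (1.0 + beta) * mask / (1.0 + beta * (mask / mask_sin.clamp_min(eps)) ** 2)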
toolbox/torchaudio/models/spectrum_unet_irm/yaml/config.yaml CHANGED
@@ -33,3 +33,6 @@ decoder_emb_num_layers: 3
 decoder_emb_skip_op: "none"
 decoder_emb_linear_groups: 16
 decoder_emb_hidden_size: 256
+
+# runtime
+use_post_filter: true