Spaces:
Running
Running
update
Browse files- examples/spectrum_unet_irm_aishell/run.sh +1 -1
- examples/spectrum_unet_irm_aishell/step_2_train_model.py +8 -2
- examples/spectrum_unet_irm_aishell/yaml/config.yaml +3 -0
- examples/test.py +0 -18
- requirements-python-3-9-9.txt +1 -0
- requirements.txt +1 -0
- toolbox/torch/training/__init__.py +6 -0
- toolbox/torch/training/metrics/__init__.py +6 -0
- toolbox/torch/training/metrics/pesq.py +108 -0
- toolbox/torchaudio/models/spectrum_unet_irm/configuration_specturm_unet_irm.py +4 -0
- toolbox/torchaudio/models/spectrum_unet_irm/modeling_spectrum_unet_irm.py +3 -0
- toolbox/torchaudio/models/spectrum_unet_irm/yaml/config.yaml +3 -0
examples/spectrum_unet_irm_aishell/run.sh
CHANGED
@@ -12,7 +12,7 @@ sh run.sh --stage 1 --stop_stage 3 --system_version centos --file_folder_name fi
|
|
12 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
13 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
14 |
|
15 |
-
sh run.sh --stage
|
16 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
17 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
18 |
|
|
|
12 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
13 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
14 |
|
15 |
+
sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name file_dir \
|
16 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
17 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
18 |
|
examples/spectrum_unet_irm_aishell/step_2_train_model.py
CHANGED
@@ -295,10 +295,13 @@ def main():
|
|
295 |
if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
|
296 |
raise AssertionError("nan or inf in lsnr_prediction")
|
297 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
|
|
|
|
|
|
298 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
299 |
if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
|
300 |
raise AssertionError("nan or inf in snr_loss")
|
301 |
-
loss = irm_loss + 0
|
302 |
# loss = irm_loss
|
303 |
|
304 |
total_loss += loss.item()
|
@@ -336,8 +339,11 @@ def main():
|
|
336 |
if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
|
337 |
raise AssertionError("nan or inf in lsnr_prediction")
|
338 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
|
|
|
|
|
|
339 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
340 |
-
loss = irm_loss + 0
|
341 |
# loss = irm_loss
|
342 |
|
343 |
total_loss += loss.item()
|
|
|
295 |
if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
|
296 |
raise AssertionError("nan or inf in lsnr_prediction")
|
297 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
298 |
+
lsnr_prediction = (lsnr_prediction - config.lsnr_min) / (config.lsnr_max - config.lsnr_min)
|
299 |
+
if torch.max(lsnr_prediction) > 1 or torch.min(lsnr_prediction) < 0:
|
300 |
+
raise AssertionError(f"expected lsnr_prediction between 0 and 1.")
|
301 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
302 |
if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
|
303 |
raise AssertionError("nan or inf in snr_loss")
|
304 |
+
loss = irm_loss + 1.0 * snr_loss
|
305 |
# loss = irm_loss
|
306 |
|
307 |
total_loss += loss.item()
|
|
|
339 |
if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
|
340 |
raise AssertionError("nan or inf in lsnr_prediction")
|
341 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
342 |
+
lsnr_prediction = (lsnr_prediction - config.lsnr_min) / (config.lsnr_max - config.lsnr_min)
|
343 |
+
if torch.max(lsnr_prediction) > 1 or torch.min(lsnr_prediction) < 0:
|
344 |
+
raise AssertionError(f"expected lsnr_prediction between 0 and 1.")
|
345 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
346 |
+
loss = irm_loss + 1.0 * snr_loss
|
347 |
# loss = irm_loss
|
348 |
|
349 |
total_loss += loss.item()
|
examples/spectrum_unet_irm_aishell/yaml/config.yaml
CHANGED
@@ -33,3 +33,6 @@ decoder_emb_num_layers: 3
|
|
33 |
decoder_emb_skip_op: "none"
|
34 |
decoder_emb_linear_groups: 16
|
35 |
decoder_emb_hidden_size: 256
|
|
|
|
|
|
|
|
33 |
decoder_emb_skip_op: "none"
|
34 |
decoder_emb_linear_groups: 16
|
35 |
decoder_emb_hidden_size: 256
|
36 |
+
|
37 |
+
# runtime
|
38 |
+
use_post_filter: true
|
examples/test.py
DELETED
@@ -1,18 +0,0 @@
|
|
1 |
-
#!/usr/bin/python3
|
2 |
-
# -*- coding: utf-8 -*-
|
3 |
-
import torch
|
4 |
-
|
5 |
-
speech_spec = torch.tensor([0], dtype=torch.float32)
|
6 |
-
noise_spec = torch.tensor([0], dtype=torch.float32)
|
7 |
-
epsilon = 1e-8
|
8 |
-
|
9 |
-
|
10 |
-
result = torch.log10(
|
11 |
-
speech_spec / (noise_spec + epsilon) + epsilon
|
12 |
-
)
|
13 |
-
|
14 |
-
print(result)
|
15 |
-
|
16 |
-
|
17 |
-
if __name__ == '__main__':
|
18 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements-python-3-9-9.txt
CHANGED
@@ -8,3 +8,4 @@ openpyxl==3.1.5
|
|
8 |
torch==2.5.1
|
9 |
torchaudio==2.5.1
|
10 |
overrides==7.7.0
|
|
|
|
8 |
torch==2.5.1
|
9 |
torchaudio==2.5.1
|
10 |
overrides==7.7.0
|
11 |
+
torch-pesq==0.1.2
|
requirements.txt
CHANGED
@@ -8,3 +8,4 @@ openpyxl==3.1.5
|
|
8 |
torch==2.5.1
|
9 |
torchaudio==2.5.1
|
10 |
overrides==7.7.0
|
|
|
|
8 |
torch==2.5.1
|
9 |
torchaudio==2.5.1
|
10 |
overrides==7.7.0
|
11 |
+
torch-pesq
|
toolbox/torch/training/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == '__main__':
|
6 |
+
pass
|
toolbox/torch/training/metrics/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == '__main__':
|
6 |
+
pass
|
toolbox/torch/training/metrics/pesq.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch_pesq import PesqLoss
|
7 |
+
|
8 |
+
|
9 |
+
class Pesq(object):
|
10 |
+
def __init__(self):
|
11 |
+
pass
|
12 |
+
|
13 |
+
|
14 |
+
class CategoricalAccuracy(object):
|
15 |
+
def __init__(self, top_k: int = 1, tie_break: bool = False) -> None:
|
16 |
+
if top_k > 1 and tie_break:
|
17 |
+
raise AssertionError("Tie break in Categorical Accuracy "
|
18 |
+
"can be done only for maximum (top_k = 1)")
|
19 |
+
if top_k <= 0:
|
20 |
+
raise AssertionError("top_k passed to Categorical Accuracy must be > 0")
|
21 |
+
self._top_k = top_k
|
22 |
+
self._tie_break = tie_break
|
23 |
+
self.correct_count = 0.
|
24 |
+
self.total_count = 0.
|
25 |
+
|
26 |
+
def __call__(self,
|
27 |
+
predictions: torch.Tensor,
|
28 |
+
gold_labels: torch.Tensor,
|
29 |
+
mask: Optional[torch.Tensor] = None):
|
30 |
+
|
31 |
+
# predictions, gold_labels, mask = self.unwrap_to_tensors(predictions, gold_labels, mask)
|
32 |
+
|
33 |
+
# Some sanity checks.
|
34 |
+
num_classes = predictions.size(-1)
|
35 |
+
if gold_labels.dim() != predictions.dim() - 1:
|
36 |
+
raise AssertionError("gold_labels must have dimension == predictions.size() - 1 but "
|
37 |
+
"found tensor of shape: {}".format(predictions.size()))
|
38 |
+
if (gold_labels >= num_classes).any():
|
39 |
+
raise AssertionError("A gold label passed to Categorical Accuracy contains an id >= {}, "
|
40 |
+
"the number of classes.".format(num_classes))
|
41 |
+
|
42 |
+
predictions = predictions.view((-1, num_classes))
|
43 |
+
gold_labels = gold_labels.view(-1).long()
|
44 |
+
if not self._tie_break:
|
45 |
+
# Top K indexes of the predictions (or fewer, if there aren't K of them).
|
46 |
+
# Special case topk == 1, because it's common and .max() is much faster than .topk().
|
47 |
+
if self._top_k == 1:
|
48 |
+
top_k = predictions.max(-1)[1].unsqueeze(-1)
|
49 |
+
else:
|
50 |
+
top_k = predictions.topk(min(self._top_k, predictions.shape[-1]), -1)[1]
|
51 |
+
|
52 |
+
# This is of shape (batch_size, ..., top_k).
|
53 |
+
correct = top_k.eq(gold_labels.unsqueeze(-1)).float()
|
54 |
+
else:
|
55 |
+
# prediction is correct if gold label falls on any of the max scores. distribute score by tie_counts
|
56 |
+
max_predictions = predictions.max(-1)[0]
|
57 |
+
max_predictions_mask = predictions.eq(max_predictions.unsqueeze(-1))
|
58 |
+
# max_predictions_mask is (rows X num_classes) and gold_labels is (batch_size)
|
59 |
+
# ith entry in gold_labels points to index (0-num_classes) for ith row in max_predictions
|
60 |
+
# For each row check if index pointed by gold_label is was 1 or not (among max scored classes)
|
61 |
+
correct = max_predictions_mask[torch.arange(gold_labels.numel()).long(), gold_labels].float()
|
62 |
+
tie_counts = max_predictions_mask.sum(-1)
|
63 |
+
correct /= tie_counts.float()
|
64 |
+
correct.unsqueeze_(-1)
|
65 |
+
|
66 |
+
if mask is not None:
|
67 |
+
correct *= mask.view(-1, 1).float()
|
68 |
+
self.total_count += mask.sum()
|
69 |
+
else:
|
70 |
+
self.total_count += gold_labels.numel()
|
71 |
+
self.correct_count += correct.sum()
|
72 |
+
|
73 |
+
def get_metric(self, reset: bool = False):
|
74 |
+
"""
|
75 |
+
Returns
|
76 |
+
-------
|
77 |
+
The accumulated accuracy.
|
78 |
+
"""
|
79 |
+
if self.total_count > 1e-12:
|
80 |
+
accuracy = float(self.correct_count) / float(self.total_count)
|
81 |
+
else:
|
82 |
+
accuracy = 0.0
|
83 |
+
if reset:
|
84 |
+
self.reset()
|
85 |
+
return {'accuracy': accuracy}
|
86 |
+
|
87 |
+
def reset(self):
|
88 |
+
self.correct_count = 0.0
|
89 |
+
self.total_count = 0.0
|
90 |
+
|
91 |
+
|
92 |
+
def main():
|
93 |
+
pesq = PesqLoss(0.5,
|
94 |
+
sample_rate=8000,
|
95 |
+
)
|
96 |
+
|
97 |
+
reference = torch.randn(1, 44100)
|
98 |
+
degraded = torch.randn(1, 44100)
|
99 |
+
|
100 |
+
mos = pesq.mos(reference, degraded)
|
101 |
+
loss = pesq(reference, degraded)
|
102 |
+
|
103 |
+
print(mos, loss)
|
104 |
+
return
|
105 |
+
|
106 |
+
|
107 |
+
if __name__ == '__main__':
|
108 |
+
main()
|
toolbox/torchaudio/models/spectrum_unet_irm/configuration_specturm_unet_irm.py
CHANGED
@@ -33,6 +33,7 @@ class SpectrumUnetIRMConfig(PretrainedConfig):
|
|
33 |
decoder_emb_linear_groups: int = 16,
|
34 |
decoder_emb_hidden_size: int = 256,
|
35 |
|
|
|
36 |
**kwargs
|
37 |
):
|
38 |
super(SpectrumUnetIRMConfig, self).__init__(**kwargs)
|
@@ -67,6 +68,9 @@ class SpectrumUnetIRMConfig(PretrainedConfig):
|
|
67 |
self.decoder_emb_linear_groups = decoder_emb_linear_groups
|
68 |
self.decoder_emb_hidden_size = decoder_emb_hidden_size
|
69 |
|
|
|
|
|
|
|
70 |
|
71 |
if __name__ == "__main__":
|
72 |
pass
|
|
|
33 |
decoder_emb_linear_groups: int = 16,
|
34 |
decoder_emb_hidden_size: int = 256,
|
35 |
|
36 |
+
use_post_filter: bool = False,
|
37 |
**kwargs
|
38 |
):
|
39 |
super(SpectrumUnetIRMConfig, self).__init__(**kwargs)
|
|
|
68 |
self.decoder_emb_linear_groups = decoder_emb_linear_groups
|
69 |
self.decoder_emb_hidden_size = decoder_emb_hidden_size
|
70 |
|
71 |
+
# runtime
|
72 |
+
self.use_post_filter = use_post_filter
|
73 |
+
|
74 |
|
75 |
if __name__ == "__main__":
|
76 |
pass
|
toolbox/torchaudio/models/spectrum_unet_irm/modeling_spectrum_unet_irm.py
CHANGED
@@ -570,6 +570,9 @@ class SpectrumUnetIRM(nn.Module):
|
|
570 |
mask = torch.transpose(mask, dim0=2, dim1=1)
|
571 |
lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
|
572 |
|
|
|
|
|
|
|
573 |
# mask shape: [batch_size, freq_dim, time_steps]
|
574 |
# lsnr shape: [batch_size, 1, time_steps]
|
575 |
return mask, lsnr
|
|
|
570 |
mask = torch.transpose(mask, dim0=2, dim1=1)
|
571 |
lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
|
572 |
|
573 |
+
if not self.training and self.config.use_post_filter:
|
574 |
+
mask = self.post_filter(mask)
|
575 |
+
|
576 |
# mask shape: [batch_size, freq_dim, time_steps]
|
577 |
# lsnr shape: [batch_size, 1, time_steps]
|
578 |
return mask, lsnr
|
toolbox/torchaudio/models/spectrum_unet_irm/yaml/config.yaml
CHANGED
@@ -33,3 +33,6 @@ decoder_emb_num_layers: 3
|
|
33 |
decoder_emb_skip_op: "none"
|
34 |
decoder_emb_linear_groups: 16
|
35 |
decoder_emb_hidden_size: 256
|
|
|
|
|
|
|
|
33 |
decoder_emb_skip_op: "none"
|
34 |
decoder_emb_linear_groups: 16
|
35 |
decoder_emb_hidden_size: 256
|
36 |
+
|
37 |
+
# runtime
|
38 |
+
use_post_filter: true
|