Spaces:
Running
Running
update
Browse files
examples/spectrum_unet_irm_aishell/run.sh
CHANGED
@@ -12,7 +12,7 @@ sh run.sh --stage 1 --stop_stage 3 --system_version centos --file_folder_name fi
|
|
12 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
13 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
14 |
|
15 |
-
sh run.sh --stage 1 --stop_stage 3 --system_version centos --file_folder_name file_dir \
|
16 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
17 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
18 |
|
|
|
12 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
13 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
14 |
|
15 |
+
sh run.sh --stage 2 --stop_stage 3 --system_version centos --file_folder_name file_dir \
|
16 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
17 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
18 |
|
examples/spectrum_unet_irm_aishell/step_2_train_model.py
CHANGED
@@ -140,8 +140,10 @@ class CollateFunction(object):
|
|
140 |
# shape: [freq_dim, time_steps]
|
141 |
|
142 |
snr_db: torch.Tensor = 10 * torch.log10(
|
143 |
-
speech_spec / (noise_spec + self.epsilon)
|
144 |
)
|
|
|
|
|
145 |
snr_db_ = torch.unsqueeze(snr_db, dim=0)
|
146 |
snr_db_ = torch.unsqueeze(snr_db_, dim=0)
|
147 |
snr_db_ = self.make_unfold_snr_db(snr_db_, n_time_steps=3)
|
@@ -301,7 +303,7 @@ def main():
|
|
301 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
302 |
if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
|
303 |
raise AssertionError("nan or inf in snr_loss")
|
304 |
-
loss = irm_loss + 1.0 * snr_loss
|
305 |
# loss = irm_loss
|
306 |
|
307 |
total_loss += loss.item()
|
@@ -343,7 +345,7 @@ def main():
|
|
343 |
if torch.max(lsnr_prediction) > 1 or torch.min(lsnr_prediction) < 0:
|
344 |
raise AssertionError(f"expected lsnr_prediction between 0 and 1.")
|
345 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
346 |
-
loss = irm_loss + 1.0 * snr_loss
|
347 |
# loss = irm_loss
|
348 |
|
349 |
total_loss += loss.item()
|
|
|
140 |
# shape: [freq_dim, time_steps]
|
141 |
|
142 |
snr_db: torch.Tensor = 10 * torch.log10(
|
143 |
+
speech_spec / (noise_spec + self.epsilon)
|
144 |
)
|
145 |
+
snr_db = torch.clamp(snr_db, min=self.epsilon)
|
146 |
+
|
147 |
snr_db_ = torch.unsqueeze(snr_db, dim=0)
|
148 |
snr_db_ = torch.unsqueeze(snr_db_, dim=0)
|
149 |
snr_db_ = self.make_unfold_snr_db(snr_db_, n_time_steps=3)
|
|
|
303 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
304 |
if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
|
305 |
raise AssertionError("nan or inf in snr_loss")
|
306 |
+
loss = irm_loss + 0.1 * snr_loss
|
307 |
# loss = irm_loss
|
308 |
|
309 |
total_loss += loss.item()
|
|
|
345 |
if torch.max(lsnr_prediction) > 1 or torch.min(lsnr_prediction) < 0:
|
346 |
raise AssertionError(f"expected lsnr_prediction between 0 and 1.")
|
347 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
348 |
+
loss = irm_loss + 0.1 * snr_loss
|
349 |
# loss = irm_loss
|
350 |
|
351 |
total_loss += loss.item()
|
requirements-python-3-9-9.txt
CHANGED
@@ -8,4 +8,6 @@ openpyxl==3.1.5
|
|
8 |
torch==2.5.1
|
9 |
torchaudio==2.5.1
|
10 |
overrides==7.7.0
|
11 |
-
torch-pesq
|
|
|
|
|
|
8 |
torch==2.5.1
|
9 |
torchaudio==2.5.1
|
10 |
overrides==7.7.0
|
11 |
+
torch-pesq
|
12 |
+
torchmetrics
|
13 |
+
torchmetrics[audio]
|
requirements.txt
CHANGED
@@ -8,4 +8,6 @@ openpyxl==3.1.5
|
|
8 |
torch==2.5.1
|
9 |
torchaudio==2.5.1
|
10 |
overrides==7.7.0
|
11 |
-
torch-pesq
|
|
|
|
|
|
8 |
torch==2.5.1
|
9 |
torchaudio==2.5.1
|
10 |
overrides==7.7.0
|
11 |
+
torch-pesq==0.1.2
|
12 |
+
torchmetrics==1.6.1
|
13 |
+
torchmetrics[audio]
|
toolbox/torch/training/metrics/stoi.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import torch
|
4 |
+
from torchmetrics.functional.audio.stoi import short_time_objective_intelligibility
|
5 |
+
|
6 |
+
|
7 |
+
# 假设 reference 和 degraded 是两个音频信号的张量
|
8 |
+
reference = torch.randn(1, 16000) # 参考信号
|
9 |
+
degraded = torch.randn(1, 16000) # 降质信号
|
10 |
+
|
11 |
+
|
12 |
+
# 计算 STOI 分数
|
13 |
+
stoi_score = short_time_objective_intelligibility(reference, degraded, fs=16000)
|
14 |
+
|
15 |
+
print(f"STOI 分数: {stoi_score}")
|
16 |
+
|
17 |
+
|
18 |
+
if __name__ == '__main__':
|
19 |
+
pass
|
toolbox/torchaudio/models/spectrum_unet_irm/modeling_spectrum_unet_irm.py
CHANGED
@@ -514,12 +514,14 @@ class Decoder(nn.Module):
|
|
514 |
|
515 |
|
516 |
class SpectrumUnetIRM(nn.Module):
|
517 |
-
def __init__(self, config: SpectrumUnetIRMConfig):
|
518 |
super(SpectrumUnetIRM, self).__init__()
|
519 |
self.config = config
|
520 |
self.encoder = Encoder(config)
|
521 |
self.decoder = Decoder(config)
|
522 |
|
|
|
|
|
523 |
def post_filter(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor:
|
524 |
"""
|
525 |
总体上来说, 它会将 mask 中的值都调大一点. 可能是为了保留更多的声音以免损伤音质, 因为预测的 mask 肯定不是特别正确.
|
|
|
514 |
|
515 |
|
516 |
class SpectrumUnetIRM(nn.Module):
|
517 |
+
def __init__(self, config: SpectrumUnetIRMConfig, eps: float = 1e-8):
|
518 |
super(SpectrumUnetIRM, self).__init__()
|
519 |
self.config = config
|
520 |
self.encoder = Encoder(config)
|
521 |
self.decoder = Decoder(config)
|
522 |
|
523 |
+
self.eps = eps
|
524 |
+
|
525 |
def post_filter(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor:
|
526 |
"""
|
527 |
总体上来说, 它会将 mask 中的值都调大一点. 可能是为了保留更多的声音以免损伤音质, 因为预测的 mask 肯定不是特别正确.
|