HoneyTian committed on
Commit
6512ccb
·
1 Parent(s): f16472f
examples/spectrum_unet_irm_aishell/run.sh CHANGED
@@ -12,7 +12,7 @@ sh run.sh --stage 1 --stop_stage 3 --system_version centos --file_folder_name fi
12
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
13
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
14
 
15
- sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name file_dir \
16
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
17
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
18
 
 
12
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
13
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
14
 
15
+ sh run.sh --stage 2 --stop_stage 3 --system_version centos --file_folder_name file_dir \
16
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
17
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
18
 
examples/spectrum_unet_irm_aishell/step_2_train_model.py CHANGED
@@ -140,8 +140,10 @@ class CollateFunction(object):
140
  # shape: [freq_dim, time_steps]
141
 
142
  snr_db: torch.Tensor = 10 * torch.log10(
143
- speech_spec / (noise_spec + self.epsilon) + self.epsilon
144
  )
 
 
145
  snr_db_ = torch.unsqueeze(snr_db, dim=0)
146
  snr_db_ = torch.unsqueeze(snr_db_, dim=0)
147
  snr_db_ = self.make_unfold_snr_db(snr_db_, n_time_steps=3)
@@ -301,7 +303,7 @@ def main():
301
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
302
  if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
303
  raise AssertionError("nan or inf in snr_loss")
304
- loss = irm_loss + 1.0 * snr_loss
305
  # loss = irm_loss
306
 
307
  total_loss += loss.item()
@@ -343,7 +345,7 @@ def main():
343
  if torch.max(lsnr_prediction) > 1 or torch.min(lsnr_prediction) < 0:
344
  raise AssertionError(f"expected lsnr_prediction between 0 and 1.")
345
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
346
- loss = irm_loss + 1.0 * snr_loss
347
  # loss = irm_loss
348
 
349
  total_loss += loss.item()
 
140
  # shape: [freq_dim, time_steps]
141
 
142
  snr_db: torch.Tensor = 10 * torch.log10(
143
+ speech_spec / (noise_spec + self.epsilon)
144
  )
145
+ snr_db = torch.clamp(snr_db, min=self.epsilon)
146
+
147
  snr_db_ = torch.unsqueeze(snr_db, dim=0)
148
  snr_db_ = torch.unsqueeze(snr_db_, dim=0)
149
  snr_db_ = self.make_unfold_snr_db(snr_db_, n_time_steps=3)
 
303
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
304
  if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
305
  raise AssertionError("nan or inf in snr_loss")
306
+ loss = irm_loss + 0.1 * snr_loss
307
  # loss = irm_loss
308
 
309
  total_loss += loss.item()
 
345
  if torch.max(lsnr_prediction) > 1 or torch.min(lsnr_prediction) < 0:
346
  raise AssertionError(f"expected lsnr_prediction between 0 and 1.")
347
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
348
+ loss = irm_loss + 0.1 * snr_loss
349
  # loss = irm_loss
350
 
351
  total_loss += loss.item()
requirements-python-3-9-9.txt CHANGED
@@ -8,4 +8,6 @@ openpyxl==3.1.5
8
  torch==2.5.1
9
  torchaudio==2.5.1
10
  overrides==7.7.0
11
- torch-pesq==0.1.2
 
 
 
8
  torch==2.5.1
9
  torchaudio==2.5.1
10
  overrides==7.7.0
11
+ torch-pesq
12
+ torchmetrics
13
+ torchmetrics[audio]
requirements.txt CHANGED
@@ -8,4 +8,6 @@ openpyxl==3.1.5
8
  torch==2.5.1
9
  torchaudio==2.5.1
10
  overrides==7.7.0
11
- torch-pesq
 
 
 
8
  torch==2.5.1
9
  torchaudio==2.5.1
10
  overrides==7.7.0
11
+ torch-pesq==0.1.2
12
+ torchmetrics==1.6.1
13
+ torchmetrics[audio]==1.6.1
toolbox/torch/training/metrics/stoi.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import torch
4
+ from torchmetrics.functional.audio.stoi import short_time_objective_intelligibility
5
+
6
+
7
+ # 假设 reference 和 degraded 是两个音频信号的张量
8
+ reference = torch.randn(1, 16000) # 参考信号
9
+ degraded = torch.randn(1, 16000) # 降质信号
10
+
11
+
12
+ # 计算 STOI 分数
13
+ stoi_score = short_time_objective_intelligibility(reference, degraded, fs=16000)
14
+
15
+ print(f"STOI 分数: {stoi_score}")
16
+
17
+
18
+ if __name__ == '__main__':
19
+ pass
toolbox/torchaudio/models/spectrum_unet_irm/modeling_spectrum_unet_irm.py CHANGED
@@ -514,12 +514,14 @@ class Decoder(nn.Module):
514
 
515
 
516
  class SpectrumUnetIRM(nn.Module):
517
- def __init__(self, config: SpectrumUnetIRMConfig):
518
  super(SpectrumUnetIRM, self).__init__()
519
  self.config = config
520
  self.encoder = Encoder(config)
521
  self.decoder = Decoder(config)
522
 
 
 
523
  def post_filter(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor:
524
  """
525
  总体上来说, 它会将 mask 中的值都调大一点. 可能是为了保留更多的声音以免损伤音质, 因为预测的 mask 肯定不是特别正确.
 
514
 
515
 
516
  class SpectrumUnetIRM(nn.Module):
517
+ def __init__(self, config: SpectrumUnetIRMConfig, eps: float = 1e-8):
518
  super(SpectrumUnetIRM, self).__init__()
519
  self.config = config
520
  self.encoder = Encoder(config)
521
  self.decoder = Decoder(config)
522
 
523
+ self.eps = eps
524
+
525
  def post_filter(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor:
526
  """
527
  总体上来说, 它会将 mask 中的值都调大一点. 可能是为了保留更多的声音以免损伤音质, 因为预测的 mask 肯定不是特别正确.