HoneyTian committed
Commit 365fc03 · 1 Parent(s): 6de113d
examples/nx_clean_unet/run.sh CHANGED
@@ -12,7 +12,7 @@ sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name fi
 --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
 --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
 
-sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir2 --final_model_name nx-clean-unet-aishell-20250228 \
+sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name nx-clean-unet-aishell-20250228 \
 --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
 --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" \
 --max_epochs 100
examples/nx_clean_unet/step_3_evaluation.py CHANGED
@@ -1,6 +1,59 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
+import argparse
+import logging
+import os
+from pathlib import Path
+import sys
+import uuid
+
+pwd = os.path.abspath(os.path.dirname(__file__))
+sys.path.append(os.path.join(pwd, "../../"))
+
+import librosa
+import numpy as np
+import pandas as pd
+from scipy.io import wavfile
+import torch
+import torch.nn as nn
+import torchaudio
+from tqdm import tqdm
+
+from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig
+from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNetPretrainedModel
+from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
+    parser.add_argument("--model_dir", default="serialization_dir/best", type=str)
+    parser.add_argument("--evaluation_audio_dir", default="evaluation_audio_dir", type=str)
+
+    parser.add_argument("--limit", default=10, type=int)
+
+    args = parser.parse_args()
+    return args
+
+
+def logging_config():
+    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
+
+    logging.basicConfig(format=fmt,
+                        datefmt="%m/%d/%Y %H:%M:%S",
+                        level=logging.INFO)
+    stream_handler = logging.StreamHandler()
+    stream_handler.setLevel(logging.INFO)
+    stream_handler.setFormatter(logging.Formatter(fmt))
+
+    logger = logging.getLogger(__name__)
+
+    return logger
+
+
+def main():
+    return
 
 
 if __name__ == '__main__':
-    pass
+    main()
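Reviewer note: main() is still a stub in this commit. A minimal sketch of how the scaffolding above would typically be wired together; the evaluation loop itself is not part of this diff, so the body below (reading the valid.xlsx sheet with pandas and honoring --limit) is an assumption, not the author's code:

def main():
    args = get_args()
    logger = logging_config()

    # Hypothetical wiring: walk the validation spreadsheet, stopping at --limit rows.
    df = pd.read_excel(args.valid_dataset)
    for i, row in tqdm(df.iterrows(), total=min(len(df), args.limit)):
        if i >= args.limit:
            break
        logger.info(f"row {i}: {row.to_dict()}")
    return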
examples/nx_clean_unet/yaml/config.yaml CHANGED
@@ -12,13 +12,13 @@ down_sampling_hidden_channels: 64
 down_sampling_kernel_size: 4
 down_sampling_stride: 2
 
-tsfm_hidden_size: 256
+tsfm_hidden_size: 64
 tsfm_attention_heads: 4
 tsfm_num_blocks: 6
 tsfm_dropout_rate: 0.1
-tsfm_max_length: 1024
-tsfm_chunk_size: 1
-tsfm_num_left_chunks: 128
+tsfm_max_length: 5120
+tsfm_chunk_size: 4
+tsfm_num_left_chunks: 64
 
 discriminator_dim: 32
 discriminator_in_channel: 2
toolbox/torchaudio/models/nx_clean_unet/enhanced_audio.wav ADDED
Binary file (417 kB)
 
toolbox/torchaudio/models/nx_clean_unet/inference_nx_clean_unet.py ADDED
@@ -0,0 +1,95 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import logging
+from pathlib import Path
+import shutil
+import tempfile
+import zipfile
+
+import librosa
+import numpy as np
+import torch
+import torchaudio
+
+from project_settings import project_path
+from toolbox.torchaudio.models.nx_clean_unet.configuration_nx_clean_unet import NXCleanUNetConfig
+from toolbox.torchaudio.models.nx_clean_unet.modeling_nx_clean_unet import NXCleanUNetPretrainedModel, MODEL_FILE
+
+logger = logging.getLogger("toolbox")
+
+
+class InferenceNXCleanUNet(object):
+    def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
+        self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
+        self.device = torch.device(device)
+
+        logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
+        config, model = self.load_models(self.pretrained_model_path_or_zip_file)
+        logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")
+
+        self.config = config
+        self.model = model
+        self.model.to(device)
+        self.model.eval()
+
+    def load_models(self, model_path: str):
+        model_path = Path(model_path)
+        if model_path.name.endswith(".zip"):
+            with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
+                out_root = Path(tempfile.gettempdir()) / "nx_denoise"
+                out_root.mkdir(parents=True, exist_ok=True)
+                f_zip.extractall(path=out_root)
+            model_path = out_root / model_path.stem
+
+        config = NXCleanUNetConfig.from_pretrained(
+            pretrained_model_name_or_path=model_path.as_posix(),
+        )
+        model = NXCleanUNetPretrainedModel.from_pretrained(
+            pretrained_model_name_or_path=model_path.as_posix(),
+        )
+        model.to(self.device)
+        model.eval()
+
+        shutil.rmtree(model_path)
+        return config, model
+
+    def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor:
+        if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
+            raise AssertionError("The value range of audio samples should be between -1 and 1.")
+
+        # noisy_audio shape: [batch_size, num_samples]
+        noisy_audios = noisy_audio.to(self.device)
+
+        with torch.no_grad():
+            enhanced_audios = self.model.forward_chunk_by_chunk(noisy_audios)
+            # enhanced_audios shape: [batch_size, n_samples]
+            # enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
+
+        enhanced_audio = enhanced_audios[0]
+        # enhanced_audio shape: [num_samples,]
+        return enhanced_audio
+
+def main():
+    model_zip_file = project_path / "trained_models/nx-clean-unet-44-epoch.zip"
+    infer_nx_clean_unet = InferenceNXCleanUNet(model_zip_file)
+
+    sample_rate = 8000
+    noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_3.wav"
+    noisy_audio, _ = librosa.load(
+        noisy_audio_file.as_posix(),
+        sr=sample_rate,
+    )
+    # noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)]
+    noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
+    noisy_audio = noisy_audio.unsqueeze(dim=0)
+
+    enhanced_audio = infer_nx_clean_unet.enhancement_by_tensor(noisy_audio)
+
+    filename = "enhanced_audio.wav"
+    torchaudio.save(filename, enhanced_audio.detach().cpu().unsqueeze(dim=0), sample_rate)
+
+    return
+
+
+if __name__ == '__main__':
+    main()
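A quick smoke test for the class above, assuming a trained checkpoint zip exists at the path used in main(); the synthetic input simply avoids depending on the example wav file:

import torch

from toolbox.torchaudio.models.nx_clean_unet.inference_nx_clean_unet import InferenceNXCleanUNet

infer = InferenceNXCleanUNet("trained_models/nx-clean-unet-44-epoch.zip")
noisy = torch.rand(1, 8000) * 2 - 1   # one second at 8 kHz, values in [-1, 1]
enhanced = infer.enhancement_by_tensor(noisy)
print(enhanced.shape)                 # torch.Size([8000]); output is trimmed back to n_samples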
toolbox/torchaudio/models/nx_clean_unet/modeling_nx_clean_unet.py CHANGED
@@ -213,9 +213,47 @@ class NXCleanUNet(nn.Module):
         # enhanced_audios shape: [batch_size, 1, n_samples]
 
         enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
+        # enhanced_audios shape: [batch_size, n_samples]
 
         return enhanced_audios
 
+    def forward_chunk_by_chunk(self, noisy_audios: torch.Tensor):
+        # noisy_audios shape: [batch_size, n_samples]
+        noisy_audios = torch.unsqueeze(noisy_audios, dim=1)
+        # noisy_audios shape: [batch_size, 1, n_samples]
+
+        n_samples = noisy_audios.shape[-1]
+        padded_length = get_padding_length(
+            n_samples,
+            num_layers=self.config.down_sampling_num_layers,
+            kernel_size=self.config.down_sampling_kernel_size,
+            stride=self.config.down_sampling_stride,
+        )
+        noisy_audios_padded = F.pad(input=noisy_audios, pad=(0, padded_length - n_samples), mode="constant", value=0)
+
+        bottle_neck = self.down_sampling.forward(noisy_audios_padded)
+        # bottle_neck shape: [batch_size, channels, time_steps]
+
+        bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
+        # bottle_neck shape: [batch_size, time_steps, input_size]
+
+        bottle_neck = self.transformer.forward_chunk_by_chunk(bottle_neck)
+        # bottle_neck shape: [batch_size, time_steps, input_size]
+
+        bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
+        # bottle_neck shape: [batch_size, channels, time_steps]
+
+        enhanced_audios = self.up_sampling.forward(bottle_neck)
+
+        enhanced_audios = enhanced_audios[:, :, :n_samples]
+        # enhanced_audios shape: [batch_size, 1, n_samples]
+
+        enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
+        # enhanced_audios shape: [batch_size, n_samples]
+
+        return enhanced_audios
+
+
 
 MODEL_FILE = "generator.pt"
 
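get_padding_length is imported elsewhere and not part of this diff. A plausible sketch, assuming it pads to the smallest length that survives num_layers unpadded strided convolutions and their transposed counterparts (the real helper may differ):

import math

def get_padding_length(length: int, num_layers: int, kernel_size: int, stride: int) -> int:
    # Push the length through each strided conv: output = ceil((L - k) / s) + 1 ...
    for _ in range(num_layers):
        length = math.ceil((length - kernel_size) / stride) + 1
    # ... then invert through the matching transposed convs: output = (L - 1) * s + k.
    for _ in range(num_layers):
        length = (length - 1) * stride + kernel_size
    return length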
toolbox/torchaudio/models/nx_clean_unet/transformer/transformer.py CHANGED
@@ -509,13 +509,14 @@ class TransformerEncoder(nn.Module):
         # position_embedding shape: [1, time_steps, hidden_size]
 
         r_att_cache = []
-        for encoder_layer in self.encoder_layer_list:
+        for idx, encoder_layer in enumerate(self.encoder_layer_list):
             xs, new_att_cache = encoder_layer.forward(
                 x=xs, mask=attention_mask,
                 position_embedding=position_embedding,
-                attention_cache=attention_cache,
+                attention_cache=attention_cache[idx: idx+1],
             )
             r_att_cache.append(new_att_cache[:, :, self.chunk_size:, :])
+            # r_att_cache.append(new_att_cache)
 
         r_att_cache = torch.cat(r_att_cache, dim=0)
@@ -528,8 +529,9 @@
 
         batch_size, time_steps, _ = xs.shape
 
-        offset = 0
-        attention_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
+        # [num_blocks, attention_heads, num_left_chunks, dim]
+        # attention_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
+        attention_cache: torch.Tensor = torch.zeros((6, 8, 128, 256), device=xs.device)
         attention_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device)
 
         outputs = []
@@ -538,15 +540,15 @@
             end = begin + self.chunk_size
             chunk_xs = xs[:, begin:end, :]
 
-            ys, att_cache = self.forward_chunk(
+            ys, attention_cache = self.forward_chunk(
                 xs=chunk_xs, attention_mask=attention_mask,
-                offset=offset, attention_cache=attention_cache
+                offset=0, attention_cache=attention_cache
             )
+
             # xs shape: [batch_size, chunk_size, hidden_size]
             ys = self.output_linear.forward(ys)
             # xs shape: [batch_size, chunk_size, input_size]
 
-            offset += self.chunk_size
            outputs.append(ys)
 
        ys = torch.cat(outputs, 1)
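Reviewer note on the hard-coded torch.zeros((6, 8, 128, 256), ...): per the inline comment the axes are [num_blocks, attention_heads, num_left_chunks, dim], so the shape presumably wants to come from the config rather than literals (the literals here do not match the yaml values in this same commit). A sketch of that derivation, using the yaml keys from this commit; whether dim is the hidden size or a per-head size is an assumption:

# Hypothetical: derive the pre-allocated rolling-cache shape from the config.
attention_cache = torch.zeros(
    (
        config.tsfm_num_blocks,         # one cache slice per encoder layer
        config.tsfm_attention_heads,    # attention heads
        config.tsfm_num_left_chunks,    # cached left-context steps
        config.tsfm_hidden_size,        # cached feature dim (assumption)
    ),
    device=xs.device,
)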
toolbox/torchaudio/models/nx_clean_unet/yaml/config.yaml CHANGED
@@ -6,21 +6,29 @@ n_fft: 512
 win_size: 200
 hop_size: 80
 
+# 2**down_sampling_num_layers,
+# e.g. 2**5 = 32 means 32 samples collapse into one time step after down-sampling,
+# so one step spans 32/sample_rate = 0.004 s.
+# Hence tsfm_chunk_size=4 corresponds to 16 ms, and tsfm_chunk_size=8 to 32 ms.
+# Assuming roughly 1 s of left context per step:
+# tsfm_chunk_size=1, tsfm_num_left_chunks: 256
+# tsfm_chunk_size=4, tsfm_num_left_chunks: 64
+# tsfm_chunk_size=8, tsfm_num_left_chunks: 32
 down_sampling_num_layers: 5
 down_sampling_in_channels: 1
 down_sampling_hidden_channels: 64
 down_sampling_kernel_size: 4
 down_sampling_stride: 2
 
-tsfm_hidden_size: 1024
-tsfm_attention_heads: 8
+tsfm_hidden_size: 64
+tsfm_attention_heads: 4
 tsfm_num_blocks: 6
 tsfm_dropout_rate: 0.1
-tsfm_max_length: 1024
-tsfm_chunk_size: 1
-tsfm_num_left_chunks: 128
+tsfm_max_length: 5120
+tsfm_chunk_size: 4
+tsfm_num_left_chunks: 64
 
-discriminator_dim: 16
+discriminator_dim: 32
 discriminator_in_channel: 2
 
 compress_factor: 0.3
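The timing arithmetic in the new comment block, as a quick check (sample_rate=8000 is taken from the inference example in this commit; the "about 1 s" figures come out to 1.024 s because the chunk counts are rounded to powers of two):

sample_rate = 8000
samples_per_step = 2 ** 5                       # down_sampling_num_layers = 5 -> 32 samples per step
step_seconds = samples_per_step / sample_rate   # 32 / 8000 = 0.004 s per time step

for chunk_size, num_left_chunks in [(1, 256), (4, 64), (8, 32)]:
    chunk_ms = chunk_size * step_seconds * 1000
    left_context = chunk_size * num_left_chunks * step_seconds
    print(f"chunk_size={chunk_size}: {chunk_ms:.0f} ms per chunk, {left_context:.3f} s left context")

# chunk_size=1: 4 ms per chunk, 1.024 s left context
# chunk_size=4: 16 ms per chunk, 1.024 s left context
# chunk_size=8: 32 ms per chunk, 1.024 s left context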