Spaces:
Running
Running
update
Browse files- examples/nx_clean_unet/run.sh +1 -1
- examples/nx_clean_unet/step_3_evaluation.py +54 -1
- examples/nx_clean_unet/yaml/config.yaml +4 -4
- toolbox/torchaudio/models/nx_clean_unet/enhanced_audio.wav +0 -0
- toolbox/torchaudio/models/nx_clean_unet/inference_nx_clean_unet.py +95 -0
- toolbox/torchaudio/models/nx_clean_unet/modeling_nx_clean_unet.py +38 -0
- toolbox/torchaudio/models/nx_clean_unet/transformer/transformer.py +9 -7
- toolbox/torchaudio/models/nx_clean_unet/yaml/config.yaml +14 -6
examples/nx_clean_unet/run.sh
CHANGED
@@ -12,7 +12,7 @@ sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name fi
|
|
12 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
13 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
14 |
|
15 |
-
sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name
|
16 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
17 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" \
|
18 |
--max_epochs 100
|
|
|
12 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
13 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
14 |
|
15 |
+
sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name nx-clean-unet-aishell-20250228 \
|
16 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
17 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" \
|
18 |
--max_epochs 100
|
examples/nx_clean_unet/step_3_evaluation.py
CHANGED
@@ -1,6 +1,59 @@
|
|
1 |
#!/usr/bin/python3
|
2 |
# -*- coding: utf-8 -*-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
|
5 |
if __name__ == '__main__':
|
6 |
-
|
|
|
1 |
#!/usr/bin/python3
|
2 |
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
import logging
|
5 |
+
import os
|
6 |
+
from pathlib import Path
|
7 |
+
import sys
|
8 |
+
import uuid
|
9 |
+
|
10 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
11 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
12 |
+
|
13 |
+
import librosa
|
14 |
+
import numpy as np
|
15 |
+
import pandas as pd
|
16 |
+
from scipy.io import wavfile
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
import torchaudio
|
20 |
+
from tqdm import tqdm
|
21 |
+
|
22 |
+
from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig
|
23 |
+
from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNetPretrainedModel
|
24 |
+
from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft
|
25 |
+
|
26 |
+
|
27 |
+
def get_args():
    """Parse the command-line options of the evaluation script.

    Options (all optional):
        --valid_dataset:        str, defaults to "valid.xlsx".
        --model_dir:            str, defaults to "serialization_dir/best".
        --evaluation_audio_dir: str, defaults to "evaluation_audio_dir".
        --limit:                int, defaults to 10.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    # (flag, default value, type converter) for every supported option.
    option_specs = (
        ("--valid_dataset", "valid.xlsx", str),
        ("--model_dir", "serialization_dir/best", str),
        ("--evaluation_audio_dir", "evaluation_audio_dir", str),
        ("--limit", 10, int),
    )

    parser = argparse.ArgumentParser()
    for flag, default_value, converter in option_specs:
        parser.add_argument(flag, default=default_value, type=converter)

    return parser.parse_args()
|
37 |
+
|
38 |
+
|
39 |
+
def logging_config():
    """Configure root logging at INFO level and return this module's logger.

    ``logging.basicConfig`` installs the root handler with the format and
    date format below (it is a no-op when the root logger already has
    handlers).  The previous version also built a separate ``StreamHandler``
    that was never attached to any logger; that dead object has been removed,
    which does not change what gets logged.

    Returns:
        The ``logging.Logger`` named after this module (``__name__``).
    """
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(
        format=fmt,
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    return logging.getLogger(__name__)
|
52 |
+
|
53 |
+
|
54 |
+
def main():
    """Entry point of the evaluation script; currently a placeholder that does nothing."""
    return None
|
56 |
|
57 |
|
58 |
if __name__ == '__main__':
|
59 |
+
main()
|
examples/nx_clean_unet/yaml/config.yaml
CHANGED
@@ -12,13 +12,13 @@ down_sampling_hidden_channels: 64
|
|
12 |
down_sampling_kernel_size: 4
|
13 |
down_sampling_stride: 2
|
14 |
|
15 |
-
tsfm_hidden_size:
|
16 |
tsfm_attention_heads: 4
|
17 |
tsfm_num_blocks: 6
|
18 |
tsfm_dropout_rate: 0.1
|
19 |
-
tsfm_max_length:
|
20 |
-
tsfm_chunk_size:
|
21 |
-
tsfm_num_left_chunks:
|
22 |
|
23 |
discriminator_dim: 32
|
24 |
discriminator_in_channel: 2
|
|
|
12 |
down_sampling_kernel_size: 4
|
13 |
down_sampling_stride: 2
|
14 |
|
15 |
+
tsfm_hidden_size: 64
|
16 |
tsfm_attention_heads: 4
|
17 |
tsfm_num_blocks: 6
|
18 |
tsfm_dropout_rate: 0.1
|
19 |
+
tsfm_max_length: 5120
|
20 |
+
tsfm_chunk_size: 4
|
21 |
+
tsfm_num_left_chunks: 64
|
22 |
|
23 |
discriminator_dim: 32
|
24 |
discriminator_in_channel: 2
|
toolbox/torchaudio/models/nx_clean_unet/enhanced_audio.wav
ADDED
Binary file (417 kB). View file
|
|
toolbox/torchaudio/models/nx_clean_unet/inference_nx_clean_unet.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import logging
|
4 |
+
from pathlib import Path
|
5 |
+
import shutil
|
6 |
+
import tempfile
|
7 |
+
import zipfile
|
8 |
+
|
9 |
+
import librosa
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import torchaudio
|
13 |
+
|
14 |
+
from project_settings import project_path
|
15 |
+
from toolbox.torchaudio.models.nx_clean_unet.configuration_nx_clean_unet import NXCleanUNetConfig
|
16 |
+
from toolbox.torchaudio.models.nx_clean_unet.modeling_nx_clean_unet import NXCleanUNetPretrainedModel, MODEL_FILE
|
17 |
+
|
18 |
+
logger = logging.getLogger("toolbox")
|
19 |
+
|
20 |
+
|
21 |
+
class InferenceNXCleanUNet(object):
    """Inference wrapper around a pretrained NXCleanUNet model.

    The model is loaded once at construction time, moved to ``device`` and
    put into eval mode; ``enhancement_by_tensor`` then runs the model's
    chunk-by-chunk forward pass on a batch of waveforms.
    """

    def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
        """Load the model.

        Args:
            pretrained_model_path_or_zip_file: a model directory, or a
                ``.zip`` archive of one (see ``load_models``).
            device: anything accepted by ``torch.device`` (default "cpu").
        """
        self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
        self.device = torch.device(device)

        logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
        config, model = self.load_models(self.pretrained_model_path_or_zip_file)
        logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")

        self.config = config
        self.model = model
        self.model.to(device)
        self.model.eval()

    def load_models(self, model_path: str):
        """Load the config and model from a directory or a ``.zip`` archive.

        A ``.zip`` archive is extracted under ``<tempdir>/nx_denoise`` and is
        expected to contain a top-level directory named after the archive's
        stem; that extracted directory is removed again once loading has
        finished (or failed).  A plain directory passed by the caller is
        never deleted.

        Args:
            model_path: path of the model directory or zip archive.

        Returns:
            A ``(config, model)`` tuple; the model is on ``self.device`` and
            in eval mode.
        """
        model_path = Path(model_path)

        # Only set when this call created the directory by extracting a zip,
        # so cleanup can never touch a directory owned by the caller.
        extracted_dir = None
        if model_path.name.endswith(".zip"):
            out_root = Path(tempfile.gettempdir()) / "nx_denoise"
            out_root.mkdir(parents=True, exist_ok=True)
            with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
                f_zip.extractall(path=out_root)
            model_path = out_root / model_path.stem
            extracted_dir = model_path

        try:
            config = NXCleanUNetConfig.from_pretrained(
                pretrained_model_name_or_path=model_path.as_posix(),
            )
            model = NXCleanUNetPretrainedModel.from_pretrained(
                pretrained_model_name_or_path=model_path.as_posix(),
            )
        finally:
            # Remove the temporary extraction even when loading raises.
            if extracted_dir is not None:
                shutil.rmtree(extracted_dir, ignore_errors=True)

        model.to(self.device)
        model.eval()
        return config, model

    def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor:
        """Enhance a batch of waveforms and return the first enhanced item.

        Args:
            noisy_audio: tensor of shape [batch_size, num_samples] with
                sample values inside [-1, 1].

        Returns:
            The enhanced waveform of the FIRST batch item only,
            shape [num_samples,].

        Raises:
            AssertionError: if any sample lies outside [-1, 1].
        """
        if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
            raise AssertionError("The value range of audio samples should be between -1 and 1.")

        # noisy_audio shape: [batch_size, num_samples]
        noisy_audios = noisy_audio.to(self.device)

        with torch.no_grad():
            enhanced_audios = self.model.forward_chunk_by_chunk(noisy_audios)
            # enhanced_audios shape: [batch_size, n_samples]

        enhanced_audio = enhanced_audios[0]
        # enhanced_audio shape: [num_samples,]
        return enhanced_audio
|
71 |
+
|
72 |
+
def main():
    """Demo: enhance one example recording and write it to ``enhanced_audio.wav``.

    Loads the packaged model archive, reads the example wav at 8000 Hz,
    runs it through ``InferenceNXCleanUNet.enhancement_by_tensor`` as a
    batch of one, and saves the result in the current working directory.
    """
    infer_nx_clean_unet = InferenceNXCleanUNet(
        project_path / "trained_models/nx-clean-unet-44-epoch.zip"
    )

    sample_rate = 8000
    noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_3.wav"
    waveform, _ = librosa.load(noisy_audio_file.as_posix(), sr=sample_rate)

    # shape: [num_samples,] -> [1, num_samples]
    batch = torch.tensor(waveform, dtype=torch.float32).unsqueeze(dim=0)

    enhanced_audio = infer_nx_clean_unet.enhancement_by_tensor(batch)

    # torchaudio.save gets a 2-D tensor, so add the leading dim back.
    torchaudio.save(
        "enhanced_audio.wav",
        enhanced_audio.detach().cpu().unsqueeze(dim=0),
        sample_rate,
    )

    return
|
92 |
+
|
93 |
+
|
94 |
+
if __name__ == '__main__':
|
95 |
+
main()
|
toolbox/torchaudio/models/nx_clean_unet/modeling_nx_clean_unet.py
CHANGED
@@ -213,9 +213,47 @@ class NXCleanUNet(nn.Module):
|
|
213 |
# enhanced_audios shape: [batch_size, 1, n_samples]
|
214 |
|
215 |
enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
|
|
|
216 |
|
217 |
return enhanced_audios
|
218 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
219 |
|
220 |
MODEL_FILE = "generator.pt"
|
221 |
|
|
|
213 |
# enhanced_audios shape: [batch_size, 1, n_samples]
|
214 |
|
215 |
enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
|
216 |
+
# enhanced_audios shape: [batch_size, n_samples]
|
217 |
|
218 |
return enhanced_audios
|
219 |
|
220 |
+
def forward_chunk_by_chunk(self, noisy_audios: torch.Tensor):
    """Enhance waveforms, running the transformer in chunk-by-chunk mode.

    The input is zero-padded on the right to the length returned by
    ``get_padding_length``, passed through down-sampling, the transformer's
    ``forward_chunk_by_chunk`` and up-sampling, then cropped back to the
    original number of samples.

    Args:
        noisy_audios: tensor of shape [batch_size, n_samples].

    Returns:
        Tensor of shape [batch_size, n_samples].
    """
    # [batch_size, n_samples] -> [batch_size, 1, n_samples]
    waveforms = noisy_audios.unsqueeze(1)
    n_samples = waveforms.shape[-1]

    target_length = get_padding_length(
        n_samples,
        num_layers=self.config.down_sampling_num_layers,
        kernel_size=self.config.down_sampling_kernel_size,
        stride=self.config.down_sampling_stride,
    )
    right_pad = target_length - n_samples
    padded = F.pad(input=waveforms, pad=(0, right_pad), mode="constant", value=0)

    # [batch_size, channels, time_steps]
    hidden = self.down_sampling.forward(padded)

    # The transformer works on [batch_size, time_steps, input_size].
    hidden = hidden.transpose(-2, -1)
    hidden = self.transformer.forward_chunk_by_chunk(hidden)
    hidden = hidden.transpose(-2, -1)
    # back to [batch_size, channels, time_steps]

    restored = self.up_sampling.forward(hidden)

    # Drop the right padding: [batch_size, 1, n_samples]
    restored = restored[:, :, :n_samples]

    # [batch_size, n_samples]
    return restored.squeeze(dim=1)
|
255 |
+
|
256 |
+
|
257 |
|
258 |
MODEL_FILE = "generator.pt"
|
259 |
|
toolbox/torchaudio/models/nx_clean_unet/transformer/transformer.py
CHANGED
@@ -509,13 +509,14 @@ class TransformerEncoder(nn.Module):
|
|
509 |
# position_embedding shape: [1, time_steps, hidden_size]
|
510 |
|
511 |
r_att_cache = []
|
512 |
-
for encoder_layer in self.encoder_layer_list:
|
513 |
xs, new_att_cache = encoder_layer.forward(
|
514 |
x=xs, mask=attention_mask,
|
515 |
position_embedding=position_embedding,
|
516 |
-
attention_cache=attention_cache,
|
517 |
)
|
518 |
r_att_cache.append(new_att_cache[:, :, self.chunk_size:, :])
|
|
|
519 |
|
520 |
r_att_cache = torch.cat(r_att_cache, dim=0)
|
521 |
|
@@ -528,8 +529,9 @@ class TransformerEncoder(nn.Module):
|
|
528 |
|
529 |
batch_size, time_steps, _ = xs.shape
|
530 |
|
531 |
-
|
532 |
-
attention_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
|
|
|
533 |
attention_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device)
|
534 |
|
535 |
outputs = []
|
@@ -538,15 +540,15 @@ class TransformerEncoder(nn.Module):
|
|
538 |
end = begin + self.chunk_size
|
539 |
chunk_xs = xs[:, begin:end, :]
|
540 |
|
541 |
-
ys,
|
542 |
xs=chunk_xs, attention_mask=attention_mask,
|
543 |
-
offset=
|
544 |
)
|
|
|
545 |
# xs shape: [batch_size, chunk_size, hidden_size]
|
546 |
ys = self.output_linear.forward(ys)
|
547 |
# xs shape: [batch_size, chunk_size, input_size]
|
548 |
|
549 |
-
offset += self.chunk_size
|
550 |
outputs.append(ys)
|
551 |
|
552 |
ys = torch.cat(outputs, 1)
|
|
|
509 |
# position_embedding shape: [1, time_steps, hidden_size]
|
510 |
|
511 |
r_att_cache = []
|
512 |
+
for idx, encoder_layer in enumerate(self.encoder_layer_list):
|
513 |
xs, new_att_cache = encoder_layer.forward(
|
514 |
x=xs, mask=attention_mask,
|
515 |
position_embedding=position_embedding,
|
516 |
+
attention_cache=attention_cache[idx: idx+1],
|
517 |
)
|
518 |
r_att_cache.append(new_att_cache[:, :, self.chunk_size:, :])
|
519 |
+
# r_att_cache.append(new_att_cache)
|
520 |
|
521 |
r_att_cache = torch.cat(r_att_cache, dim=0)
|
522 |
|
|
|
529 |
|
530 |
batch_size, time_steps, _ = xs.shape
|
531 |
|
532 |
+
# [num_blocks, attention_heads, num_left_chunks, dim]
|
533 |
+
# attention_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
|
534 |
+
attention_cache: torch.Tensor = torch.zeros((6, 8, 128, 256), device=xs.device)
|
535 |
attention_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device)
|
536 |
|
537 |
outputs = []
|
|
|
540 |
end = begin + self.chunk_size
|
541 |
chunk_xs = xs[:, begin:end, :]
|
542 |
|
543 |
+
ys, attention_cache = self.forward_chunk(
|
544 |
xs=chunk_xs, attention_mask=attention_mask,
|
545 |
+
offset=0, attention_cache=attention_cache
|
546 |
)
|
547 |
+
|
548 |
# xs shape: [batch_size, chunk_size, hidden_size]
|
549 |
ys = self.output_linear.forward(ys)
|
550 |
# xs shape: [batch_size, chunk_size, input_size]
|
551 |
|
|
|
552 |
outputs.append(ys)
|
553 |
|
554 |
ys = torch.cat(outputs, 1)
|
toolbox/torchaudio/models/nx_clean_unet/yaml/config.yaml
CHANGED
@@ -6,21 +6,29 @@ n_fft: 512
|
|
6 |
win_size: 200
|
7 |
hop_size: 80
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
down_sampling_num_layers: 5
|
10 |
down_sampling_in_channels: 1
|
11 |
down_sampling_hidden_channels: 64
|
12 |
down_sampling_kernel_size: 4
|
13 |
down_sampling_stride: 2
|
14 |
|
15 |
-
tsfm_hidden_size:
|
16 |
-
tsfm_attention_heads:
|
17 |
tsfm_num_blocks: 6
|
18 |
tsfm_dropout_rate: 0.1
|
19 |
-
tsfm_max_length:
|
20 |
-
tsfm_chunk_size:
|
21 |
-
tsfm_num_left_chunks:
|
22 |
|
23 |
-
discriminator_dim:
|
24 |
discriminator_in_channel: 2
|
25 |
|
26 |
compress_factor: 0.3
|
|
|
6 |
win_size: 200
|
7 |
hop_size: 80
|
8 |
|
9 |
+
# 2**down_sampling_num_layers,
|
10 |
+
# e.g. 2**5 = 32 means that 32 samples become one time step after down-sampling,
|
11 |
+
# so one time step lasts 32 / sample_rate = 0.004 s.
|
12 |
+
# Hence tsfm_chunk_size=4 corresponds to 16 ms and tsfm_chunk_size=8 to 32 ms.
|
13 |
+
# Assuming a left context (look-back) of about 1 second, use:
|
14 |
+
# tsfm_chunk_size=1,tsfm_num_left_chunks: 256
|
15 |
+
# tsfm_chunk_size=4,tsfm_num_left_chunks: 64
|
16 |
+
# tsfm_chunk_size=8,tsfm_num_left_chunks: 32
|
17 |
down_sampling_num_layers: 5
|
18 |
down_sampling_in_channels: 1
|
19 |
down_sampling_hidden_channels: 64
|
20 |
down_sampling_kernel_size: 4
|
21 |
down_sampling_stride: 2
|
22 |
|
23 |
+
tsfm_hidden_size: 64
|
24 |
+
tsfm_attention_heads: 4
|
25 |
tsfm_num_blocks: 6
|
26 |
tsfm_dropout_rate: 0.1
|
27 |
+
tsfm_max_length: 5120
|
28 |
+
tsfm_chunk_size: 4
|
29 |
+
tsfm_num_left_chunks: 64
|
30 |
|
31 |
+
discriminator_dim: 32
|
32 |
discriminator_in_channel: 2
|
33 |
|
34 |
compress_factor: 0.3
|