#!/usr/bin/python3
# -*- coding: utf-8 -*-
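"""
Inference wrapper for the DTLN speech-denoising model.

Loads a pretrained model from a directory or a .zip archive and exposes
offline (whole-utterance) and online (chunk-by-chunk) enhancement.
"""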
import logging
from pathlib import Path
import shutil
import tempfile
import time
import zipfile
import librosa
import numpy as np
import torch
import torchaudio
torch.set_num_threads(1)
from project_settings import project_path
from toolbox.torchaudio.models.dtln.configuration_dtln import DTLNConfig
from toolbox.torchaudio.models.dtln.modeling_dtln import DTLNPretrainedModel, MODEL_FILE
logger = logging.getLogger("toolbox")


class InferenceDTLN(object):
    """Thin wrapper that loads a pretrained DTLN model and denoises audio."""

    def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
        self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
        self.device = torch.device(device)

        logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
        config, model = self.load_models(self.pretrained_model_path_or_zip_file)
        logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")

        self.config = config
        self.model = model
        self.model.to(self.device)
        self.model.eval()

    def load_models(self, model_path: str):
        model_path = Path(model_path)
        extracted = False
        if model_path.name.endswith(".zip"):
            # Unpack the archive into a temporary directory; the zip is
            # expected to contain a top-level directory named after its stem.
            with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
                out_root = Path(tempfile.gettempdir()) / "nx_denoise"
                out_root.mkdir(parents=True, exist_ok=True)
                f_zip.extractall(path=out_root)
            model_path = out_root / model_path.stem
            extracted = True

        config = DTLNConfig.from_pretrained(
            pretrained_model_name_or_path=model_path.as_posix(),
        )
        model = DTLNPretrainedModel.from_pretrained(
            pretrained_model_name_or_path=model_path.as_posix(),
        )
        model.to(self.device)
        model.eval()

        if extracted:
            # Only delete the temporary extraction; never remove a
            # user-supplied model directory.
            shutil.rmtree(model_path)
        return config, model

    def enhancement_by_ndarray(self, noisy_audio: np.ndarray) -> np.ndarray:
        noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
        noisy_audio = noisy_audio.unsqueeze(dim=0)
        # noisy_audio shape: [batch_size, num_samples]
        enhanced_audio = self.denoise_offline(noisy_audio)
        # enhanced_audio shape: [channels, num_samples]
        enhanced_audio = enhanced_audio[0]
        # enhanced_audio shape: [num_samples]
        return enhanced_audio.cpu().numpy()
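
    # Usage sketch (paths and the 8 kHz rate are illustrative, mirroring
    # main() below):
    #
    #   infer = InferenceDTLN("trained_models/dtln-nx-dns3.zip")
    #   noisy, sr = librosa.load("noisy.wav", sr=8000)
    #   clean = infer.enhancement_by_ndarray(noisy)  # np.ndarray, shape [num_samples]
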
    def denoise_offline(self, noisy_audio: torch.Tensor) -> torch.Tensor:
        if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
            raise AssertionError("The value range of audio samples should be between -1 and 1.")
        # noisy_audio shape: [batch_size, num_samples]
        noisy_audios = noisy_audio.to(self.device)

        with torch.no_grad():
            # One forward pass over the whole utterance.
            denoise = self.model.forward(noisy_audios)
            # denoise shape: [batch_size, 1, num_samples]

        denoise = denoise[0]
        # denoise shape: [channels, num_samples]
        return denoise

    def denoise_online(self, noisy_audio: torch.Tensor) -> torch.Tensor:
        if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
            raise AssertionError("The value range of audio samples should be between -1 and 1.")
        # noisy_audio shape: [batch_size, num_samples]
        noisy_audios = noisy_audio.to(self.device)

        with torch.no_grad():
            # Chunk-by-chunk forward emulates streaming (online) inference.
            denoise = self.model.forward_chunk_by_chunk(noisy_audios)
            # denoise shape: [batch_size, 1, num_samples]

        denoise = denoise[0]
        # denoise shape: [channels, num_samples]
        return denoise
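

# A minimal file-to-file sketch built on the class above. `denoise_wav_file`
# is a hypothetical convenience helper, not part of the original API; it
# assumes 8 kHz mono input, matching main() below.
def denoise_wav_file(infer_model: InferenceDTLN, in_file: str, out_file: str, sample_rate: int = 8000) -> None:
    # Load mono audio resampled to the model's expected rate.
    noisy_audio, _ = librosa.load(in_file, sr=sample_rate)
    # enhanced shape: [num_samples]
    enhanced = infer_model.enhancement_by_ndarray(noisy_audio)
    # torchaudio expects a [channels, num_samples] tensor.
    torchaudio.save(out_file, torch.tensor(enhanced, dtype=torch.float32).unsqueeze(dim=0), sample_rate)
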

def main():
    model_zip_file = project_path / "trained_models/dtln-nx-dns3.zip"
    infer_model = InferenceDTLN(model_zip_file.as_posix())

    sample_rate = 8000
    noisy_audio_file = project_path / "data/examples/ai_agent/chinese-3.wav"
    noisy_audio, sample_rate = librosa.load(
        noisy_audio_file.as_posix(),
        sr=sample_rate,
    )
    duration = librosa.get_duration(y=noisy_audio, sr=sample_rate)
    # noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)]
    noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
    noisy_audio = noisy_audio.unsqueeze(dim=0)

    # offline (whole-utterance) enhancement; rtf is the real-time factor,
    # i.e. processing time divided by audio duration.
    begin = time.time()
    enhanced_audio = infer_model.denoise_offline(noisy_audio)
    time_cost = time.time() - begin
    print(f"enhanced_audio.shape: {enhanced_audio.shape}, time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, rtf: {time_cost / duration:.4f}")

    filename = "enhanced_audio_offline.wav"
    torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate)

    # online (chunk-by-chunk) enhancement
    begin = time.time()
    enhanced_audio = infer_model.denoise_online(noisy_audio)
    time_cost = time.time() - begin
    print(f"enhanced_audio.shape: {enhanced_audio.shape}, time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, rtf: {time_cost / duration:.4f}")

    filename = "enhanced_audio_online.wav"
    torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate)
    return


if __name__ == "__main__":
    main()