#!/usr/bin/python3
# -*- coding: utf-8 -*-
import logging
from pathlib import Path
import shutil
import tempfile
import time
import zipfile

import librosa
import numpy as np
import torch
import torchaudio

torch.set_num_threads(1)

from project_settings import project_path
from toolbox.torchaudio.models.dfnet2.configuration_dfnet2 import DfNet2Config
from toolbox.torchaudio.models.dfnet2.modeling_dfnet2 import DfNet2PretrainedModel, MODEL_FILE

logger = logging.getLogger("toolbox")


class InferenceDfNet(object):
    def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
        self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
        self.device = torch.device(device)

        logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
        config, model = self.load_models(self.pretrained_model_path_or_zip_file)
        logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")

        self.config = config
        self.model = model
        self.model.to(self.device)
        self.model.eval()

    def load_models(self, model_path: str):
        model_path = Path(model_path)
        extracted = False
        if model_path.name.endswith(".zip"):
            with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
                out_root = Path(tempfile.gettempdir()) / "nx_denoise"
                out_root.mkdir(parents=True, exist_ok=True)
                f_zip.extractall(path=out_root)
            model_path = out_root / model_path.stem
            extracted = True

        config = DfNet2Config.from_pretrained(
            pretrained_model_name_or_path=model_path.as_posix(),
        )
        model = DfNet2PretrainedModel.from_pretrained(
            pretrained_model_name_or_path=model_path.as_posix(),
        )
        model.to(self.device)
        model.eval()

        # Only clean up when the files were extracted to the temp directory;
        # removing a user-supplied model directory here would be destructive.
        if extracted:
            shutil.rmtree(model_path)
        return config, model
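
    # Note: the zip archive is assumed to contain a top-level directory named
    # after the archive itself (model_path.stem) holding the config and weight
    # files that from_pretrained expects.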

    def enhancement_by_ndarray(self, noisy_audio: np.ndarray) -> np.ndarray:
        noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
        noisy_audio = noisy_audio.unsqueeze(dim=0)
        # noisy_audio shape: [batch_size, num_samples]
        enhanced_audio = self.denoise_offline(noisy_audio)
        # enhanced_audio shape: [channels, num_samples]
        enhanced_audio = enhanced_audio[0]
        # enhanced_audio shape: [num_samples]
        return enhanced_audio.cpu().numpy()

    def denoise_offline(self, noisy_audio: torch.Tensor) -> torch.Tensor:
        if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
            raise AssertionError("The value range of audio samples should be between -1 and 1.")
        # noisy_audio shape: [batch_size, num_samples]
        noisy_audios = noisy_audio.to(self.device)
        with torch.no_grad():
            est_spec, est_wav, est_mask, lsnr = self.model.forward(noisy_audios)
            # est_wav shape: [batch_size, 1, num_samples]
        denoise = est_wav[0]
        # denoise shape: [channels, num_samples]
        return denoise
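
    # denoise_offline runs a single full-utterance forward pass; denoise_online
    # feeds the same waveform through the model chunk by chunk to simulate
    # streaming, so the two outputs may differ slightly at chunk boundaries.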

    def denoise_online(self, noisy_audio: torch.Tensor) -> torch.Tensor:
        if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
            raise AssertionError("The value range of audio samples should be between -1 and 1.")
        # noisy_audio shape: [batch_size, num_samples]
        noisy_audios = noisy_audio.to(self.device)
        with torch.no_grad():
            denoise = self.model.forward_chunk_by_chunk(noisy_audios)
            # denoise shape: [batch_size, 1, num_samples]
        denoise = denoise[0]
        # denoise shape: [channels, num_samples]
        return denoise
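

# A minimal file-level convenience sketch (not part of the original class API):
# load a wav with librosa at the model's expected sample rate and run the
# offline enhancement path end to end. The 8 kHz default mirrors main() below.
def denoise_wav_file(infer_model: InferenceDfNet, wav_file: str, sample_rate: int = 8000) -> np.ndarray:
    # librosa.load resamples to `sr` and returns a mono float32 array in [-1, 1].
    noisy_audio, _ = librosa.load(wav_file, sr=sample_rate)
    return infer_model.enhancement_by_ndarray(noisy_audio)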


def main():
    model_zip_file = project_path / "trained_models/dfnet2-nx-dns3.zip"
    infer_model = InferenceDfNet(model_zip_file)

    sample_rate = 8000
    noisy_audio_file = project_path / "data/examples/ai_agent/chinese-3.wav"
    noisy_audio, sample_rate = librosa.load(
        noisy_audio_file.as_posix(),
        sr=sample_rate,
    )
    duration = librosa.get_duration(y=noisy_audio, sr=sample_rate)
    # noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)]
    noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
    noisy_audio = noisy_audio.unsqueeze(dim=0)

    begin = time.time()
    enhanced_audio = infer_model.denoise_offline(noisy_audio)
    time_cost = time.time() - begin
    # rtf: real-time factor, i.e. processing time divided by audio duration.
    print(f"enhanced_audio.shape: {enhanced_audio.shape}, time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, rtf: {time_cost / duration:.4f}")

    filename = "enhanced_audio_offline.wav"
    torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate)

    begin = time.time()
    enhanced_audio = infer_model.denoise_online(noisy_audio)
    time_cost = time.time() - begin
    print(f"enhanced_audio.shape: {enhanced_audio.shape}, time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, rtf: {time_cost / duration:.4f}")

    filename = "enhanced_audio_online.wav"
    torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate)
    return


if __name__ == "__main__":
    main()