HoneyTian's picture
update
bd3d872
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import logging
from pathlib import Path
import shutil
import tempfile, time
import zipfile
import librosa
import numpy as np
import torch
import torchaudio
torch.set_num_threads(1)
from project_settings import project_path
from toolbox.torchaudio.models.dtln.configuration_dtln import DTLNConfig
from toolbox.torchaudio.models.dtln.modeling_dtln import DTLNPretrainedModel, MODEL_FILE
# Shared "toolbox" logger; InferenceDTLN reports model-loading progress through it.
logger = logging.getLogger("toolbox")
class InferenceDTLN(object):
    """Run a pretrained DTLN speech-enhancement model on waveforms.

    The model is loaded once at construction, moved to ``device`` and put in
    eval mode; the ``denoise_*`` methods then run it under ``torch.no_grad``.
    Input samples must lie in [-1, 1].
    """

    def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
        """Load the config and model.

        :param pretrained_model_path_or_zip_file: a pretrained-model directory, or a
            ``.zip`` archive containing a top-level folder named after the archive
            stem. A ``pathlib.Path`` is accepted as well.
        :param device: torch device string, e.g. ``"cpu"`` or ``"cuda:0"``.
        """
        self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
        self.device = torch.device(device)

        logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
        config, model = self.load_models(self.pretrained_model_path_or_zip_file)
        logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")

        self.config = config
        self.model = model
        self.model.to(self.device)
        self.model.eval()

    def load_models(self, model_path: str):
        """Load ``(config, model)`` from a directory or a ``.zip`` archive.

        A directory is loaded in place and is never modified or deleted.
        A ``.zip`` archive is extracted into a private temporary directory.
        That directory is removed afterwards, whether or not loading succeeds.

        :param model_path: directory or ``.zip`` path (``str`` or ``Path``).
        :return: ``(config, model)``; the model is on ``self.device`` in eval mode.
        """
        model_path = Path(model_path)
        if not model_path.name.endswith(".zip"):
            # Caller-owned directory: load it, but never clean it up.
            return self._load_from_directory(model_path)

        # Only the temporary extraction is deleted, and always (even on error).
        with tempfile.TemporaryDirectory(prefix="nx_denoise_") as out_root:
            with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
                f_zip.extractall(path=out_root)
            # The archive is expected to hold a folder named after its stem.
            # NOTE(review): assumes from_pretrained reads everything eagerly, as the
            # original code also deleted the files right after loading.
            return self._load_from_directory(Path(out_root) / model_path.stem)

    def _load_from_directory(self, model_dir: Path):
        """Build config and model from ``model_dir``; move the model to ``self.device``, eval mode."""
        config = DTLNConfig.from_pretrained(
            pretrained_model_name_or_path=model_dir.as_posix(),
        )
        model = DTLNPretrainedModel.from_pretrained(
            pretrained_model_name_or_path=model_dir.as_posix(),
        )
        model.to(self.device)
        model.eval()
        return config, model

    @staticmethod
    def _check_value_range(noisy_audio: torch.Tensor) -> None:
        """Raise ``AssertionError`` if any sample lies outside [-1, 1].

        ``AssertionError`` is kept (rather than ``ValueError``) for callers that
        already catch it.
        """
        # torch.max/torch.min raise on an empty tensor, so skip the check then.
        if noisy_audio.numel() > 0 and (torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1):
            raise AssertionError("The value range of audio samples should be between -1 and 1.")

    def enhancement_by_ndarray(self, noisy_audio: np.ndarray) -> np.ndarray:
        """Denoise a 1-D waveform given as a NumPy array; return a 1-D NumPy array."""
        noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
        noisy_audio = noisy_audio.unsqueeze(dim=0)
        # noisy_audio shape: [batch_size, n_samples]
        enhanced_audio = self.denoise_offline(noisy_audio)
        # enhanced_audio shape: [channels, num_samples]
        enhanced_audio = enhanced_audio[0]
        # enhanced_audio shape: [num_samples]
        return enhanced_audio.cpu().numpy()

    def denoise_offline(self, noisy_audio: torch.Tensor) -> torch.Tensor:
        """Denoise the whole signal in one ``model.forward`` call.

        :param noisy_audio: ``[batch_size, num_samples]``, values in [-1, 1].
        :return: output for the first batch item only, ``[channels, num_samples]``.
        :raises AssertionError: if a sample lies outside [-1, 1].
        """
        self._check_value_range(noisy_audio)
        noisy_audios = noisy_audio.to(self.device)
        with torch.no_grad():
            denoise = self.model.forward(noisy_audios)
            # denoise shape: [batch_size, 1, num_samples]
        return denoise[0]

    def denoise_online(self, noisy_audio: torch.Tensor) -> torch.Tensor:
        """Denoise via the model's ``forward_chunk_by_chunk`` (streaming-style) path.

        :param noisy_audio: ``[batch_size, num_samples]``, values in [-1, 1].
        :return: output for the first batch item only, ``[channels, num_samples]``.
        :raises AssertionError: if a sample lies outside [-1, 1].
        """
        self._check_value_range(noisy_audio)
        noisy_audios = noisy_audio.to(self.device)
        with torch.no_grad():
            denoise = self.model.forward_chunk_by_chunk(noisy_audios)
            # denoise shape: [batch_size, 1, num_samples]
        return denoise[0]
def main():
    """Demo: denoise one example file offline and online, timing each run.

    Each mode prints its output shape, wall time, audio duration and the
    time/duration ratio, then writes the result to a WAV file in the CWD.
    """
    model_zip_file = project_path / "trained_models/dtln-nx-dns3.zip"
    infer_model = InferenceDTLN(model_zip_file)

    noisy_audio_file = project_path / "data/examples/ai_agent/chinese-3.wav"
    waveform, sample_rate = librosa.load(noisy_audio_file.as_posix(), sr=8000)
    duration = librosa.get_duration(y=waveform, sr=sample_rate)

    # [num_samples] -> [1, num_samples]: the denoise_* methods expect a batch axis.
    batch = torch.tensor(waveform, dtype=torch.float32).unsqueeze(dim=0)

    runs = (
        (infer_model.denoise_offline, "enhanced_audio_offline.wav"),
        (infer_model.denoise_online, "enhanced_audio_online.wav"),
    )
    for denoise_fn, out_filename in runs:
        start = time.time()
        enhanced_audio = denoise_fn(batch)
        time_cost = time.time() - start
        print(f"enhanced_audio.shape: {enhanced_audio.shape}, time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}")
        torchaudio.save(out_filename, enhanced_audio.detach().cpu(), sample_rate)
    return
if __name__ == "__main__":
    # Running this file directly executes the demo in main().
    main()