#!/usr/bin/python3
# -*- coding: utf-8 -*-
import logging
from pathlib import Path
import shutil
import tempfile
import zipfile

import librosa
import numpy as np
import torch
import torchaudio

from project_settings import project_path
from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig
from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNetPretrainedModel, MODEL_FILE
from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft

logger = logging.getLogger("toolbox")

class InferenceMPNet(object):
    def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
        self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
        self.device = torch.device(device)

        logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
        config, generator = self.load_models(self.pretrained_model_path_or_zip_file)
        logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")

        self.config = config
        self.generator = generator
        self.generator.to(self.device)
        self.generator.eval()
    def load_models(self, model_path: str):
        model_path = Path(model_path)
        extracted = False
        if model_path.name.endswith(".zip"):
            with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
                out_root = Path(tempfile.gettempdir()) / "nx_denoise"
                out_root.mkdir(parents=True, exist_ok=True)
                f_zip.extractall(path=out_root)
            # the archive is expected to contain a top-level folder named after the zip stem
            model_path = out_root / model_path.stem
            extracted = True

        config = MPNetConfig.from_pretrained(
            pretrained_model_name_or_path=model_path.as_posix(),
        )
        generator = MPNetPretrainedModel.from_pretrained(
            pretrained_model_name_or_path=model_path.as_posix(),
        )
        generator.to(self.device)
        generator.eval()

        # only clean up the temporary extraction directory;
        # never delete a user-supplied model directory
        if extracted:
            shutil.rmtree(model_path)
        return config, generator
    def enhancement_by_ndarray(self, noisy_audio: np.ndarray) -> np.ndarray:
        noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
        noisy_audio = noisy_audio.unsqueeze(dim=0)
        # noisy_audio shape: [batch_size, n_samples]
        enhanced_audio = self.enhancement_by_tensor(noisy_audio)
        # enhanced_audio shape: [n_samples,]
        return enhanced_audio.cpu().numpy()
    def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor:
        if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
            raise AssertionError("The value range of audio samples should be between -1 and 1.")
        # noisy_audio shape: [batch_size, n_samples]
        noisy_audio = noisy_audio.to(self.device)

        with torch.no_grad():
            # STFT -> denoise magnitude/phase with the generator -> iSTFT
            noisy_mag, noisy_pha, noisy_com = mag_pha_stft(
                noisy_audio,
                self.config.n_fft, self.config.hop_size, self.config.win_size, self.config.compress_factor
            )
            mag_g, pha_g, com_g = self.generator.forward(noisy_mag, noisy_pha)
            audio_g = mag_pha_istft(
                mag_g, pha_g,
                self.config.n_fft, self.config.hop_size, self.config.win_size, self.config.compress_factor
            )
            enhanced_audio = audio_g.detach()

        # drop the batch dimension; enhanced_audio shape: [n_samples,]
        enhanced_audio = enhanced_audio[0]
        return enhanced_audio
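
# A minimal usage sketch for the ndarray path above (not part of the original
# script): load a wav with librosa and pass it through enhancement_by_ndarray.
# The function name and its arguments are illustrative placeholders supplied
# by the caller.
def enhance_file_sketch(model_zip_file: str, noisy_audio_file: str, sample_rate: int = 8000) -> np.ndarray:
    infer = InferenceMPNet(model_zip_file)
    # librosa.load returns a float32 mono waveform in [-1, 1]
    noisy_audio, _ = librosa.load(noisy_audio_file, sr=sample_rate)
    return infer.enhancement_by_ndarray(noisy_audio)
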
def main():
    model_zip_file = project_path / "trained_models/mpnet-aishell-1-epoch.zip"
    infer_mpnet = InferenceMPNet(model_zip_file.as_posix())

    sample_rate = 8000
    noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_1.wav"
    noisy_audio, _ = librosa.load(
        noisy_audio_file.as_posix(),
        sr=sample_rate,
    )
    # keep a two-second excerpt (7 s to 9 s)
    noisy_audio = noisy_audio[int(7 * sample_rate):int(9 * sample_rate)]
    noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
    noisy_audio = noisy_audio.unsqueeze(dim=0)

    enhanced_audio = infer_mpnet.enhancement_by_tensor(noisy_audio)

    # torchaudio.save expects a 2-D [channels, n_samples] tensor,
    # so restore the channel dimension dropped by enhancement_by_tensor
    filename = "enhanced_audio.wav"
    torchaudio.save(filename, enhanced_audio.detach().cpu().unsqueeze(dim=0), sample_rate)
    return
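
# Sketch, not from the original script: enhancement_by_tensor rejects samples
# outside [-1, 1], so audio from sources that are not already normalized can be
# peak-normalized first. The helper name and epsilon are illustrative assumptions.
def peak_normalize(audio: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    peak = float(np.max(np.abs(audio)))
    # rescale only when the signal actually exceeds the expected range
    return audio / (peak + eps) if peak > 1.0 else audio
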
if __name__ == '__main__':
    main()