File size: 4,111 Bytes
1e78a70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6713e7b
 
 
1e78a70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6713e7b
1e78a70
 
33aff71
1e78a70
33aff71
1e78a70
 
 
33aff71
1e78a70
 
 
 
33aff71
1e78a70
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import logging
from pathlib import Path
import shutil
import tempfile
import zipfile

import librosa
import numpy as np
import torch
import torchaudio

from project_settings import project_path
from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig
from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNetPretrainedModel, MODEL_FILE
from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft

# Shared project logger; handlers/levels are configured by whoever sets up the "toolbox" logger.
logger = logging.getLogger(name="toolbox")


class InferenceMPNet(object):
    """Speech-enhancement inference wrapper around a pretrained MPNet generator.

    The config and generator are loaded once at construction time, moved to
    ``device`` and switched to eval mode. ``enhancement_by_ndarray`` /
    ``enhancement_by_tensor`` then run STFT -> generator -> inverse STFT.
    """

    def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
        """
        :param pretrained_model_path_or_zip_file: path (str or ``Path``) to either a
            model directory or a ``.zip`` archive containing that directory.
        :param device: torch device string, e.g. ``"cpu"`` or ``"cuda:0"``.
        """
        self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
        self.device = torch.device(device)

        logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
        config, generator = self.load_models(self.pretrained_model_path_or_zip_file)
        logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")

        self.config = config
        self.generator = generator
        self.generator.to(self.device)
        self.generator.eval()

    def load_models(self, model_path: str):
        """Load the MPNet config and generator from a directory or a ``.zip`` archive.

        A ``.zip`` archive is extracted into a fresh private temporary directory.
        That directory is removed once loading finishes, or if loading fails.
        A plain directory is loaded in place and is never modified or deleted.

        :param model_path: model directory, or ``.zip`` archive path (str or ``Path``).
        :return: ``(config, generator)``; the generator is on ``self.device`` and in eval mode.
        """
        model_path = Path(model_path)
        if not model_path.name.endswith(".zip"):
            # Load in place; the caller's directory must survive (it used to be rmtree'd).
            return self._load_from_directory(model_path)

        # A unique temp dir (rather than a shared fixed path) keeps concurrent loads
        # from clobbering each other. The context manager cleans up even on errors.
        with tempfile.TemporaryDirectory(prefix="nx_denoise_") as tmp_dir:
            with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
                f_zip.extractall(path=tmp_dir)
            # The archive is expected to hold a top-level folder named after its stem.
            return self._load_from_directory(Path(tmp_dir) / model_path.stem)

    def _load_from_directory(self, model_dir: Path):
        """Build the config and generator from a model directory; return ``(config, generator)``."""
        config = MPNetConfig.from_pretrained(
            pretrained_model_name_or_path=model_dir.as_posix(),
        )
        generator = MPNetPretrainedModel.from_pretrained(
            pretrained_model_name_or_path=model_dir.as_posix(),
        )
        generator.to(self.device)
        generator.eval()
        return config, generator

    def enhancement_by_ndarray(self, noisy_audio: np.ndarray) -> np.ndarray:
        """Enhance one waveform given as a NumPy array.

        :param noisy_audio: samples in [-1, 1]; assumed 1-D ``[n_samples]`` (TODO confirm),
            since a batch axis is prepended here.
        :return: enhanced waveform as a NumPy array (first and only batch item).
        :raises AssertionError: if any sample lies outside [-1, 1].
        """
        noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
        # [n_samples] -> [batch_size=1, n_samples]
        noisy_audio = noisy_audio.unsqueeze(dim=0)

        enhanced_audio = self.enhancement_by_tensor(noisy_audio)
        # enhanced_audio has no batch axis (enhancement_by_tensor returns item 0)
        return enhanced_audio.cpu().numpy()

    def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor:
        """Enhance a batch of waveforms and return the first result.

        :param noisy_audio: tensor of shape ``[batch_size, n_samples]`` with values in [-1, 1].
        :return: enhanced waveform of batch item 0, on ``self.device``; its length is
            whatever ``mag_pha_istft`` produces.
        :raises AssertionError: if any sample lies outside [-1, 1].
        """
        if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
            raise AssertionError("The value range of audio samples should be between -1 and 1.")

        noisy_audio = noisy_audio.to(self.device)

        with torch.no_grad():
            noisy_mag, noisy_pha, _ = mag_pha_stft(
                noisy_audio,
                self.config.n_fft, self.config.hop_size, self.config.win_size, self.config.compress_factor
            )
            mag_g, pha_g, _ = self.generator.forward(noisy_mag, noisy_pha)
            enhanced_audio = mag_pha_istft(
                mag_g, pha_g,
                self.config.n_fft, self.config.hop_size, self.config.win_size, self.config.compress_factor
            )

        # Only the first item of the batch is returned.
        return enhanced_audio[0]


def main():
    """Demo: denoise a 2-second excerpt of an example recording and save it as ``enhanced_audio.wav``."""
    model_zip_file = project_path / "trained_models/mpnet-aishell-1-epoch.zip"
    infer_mpnet = InferenceMPNet(model_zip_file)

    sample_rate = 8000
    noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_1.wav"
    # librosa resamples the file to `sample_rate` on load
    noisy_audio, _ = librosa.load(
        noisy_audio_file.as_posix(),
        sr=sample_rate,
    )
    # keep only the excerpt between 7 s and 9 s
    noisy_audio = noisy_audio[int(7 * sample_rate):int(9 * sample_rate)]
    noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
    # [n_samples] -> [batch_size=1, n_samples]
    noisy_audio = noisy_audio.unsqueeze(dim=0)

    enhanced_audio = infer_mpnet.enhancement_by_tensor(noisy_audio)

    # enhancement_by_tensor drops the batch axis, but torchaudio.save requires a
    # 2-D [channels, time] tensor, so restore a mono channel axis if it is missing.
    if enhanced_audio.dim() == 1:
        enhanced_audio = enhanced_audio.unsqueeze(dim=0)

    filename = "enhanced_audio.wav"
    torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate)


if __name__ == "__main__":
    # Run the demo only when executed as a script, never on import.
    main()