#!/usr/bin/python3
# -*- coding: utf-8 -*-
import logging
from pathlib import Path
import shutil
import tempfile
import time
import zipfile

import librosa
import numpy as np
import torch
import torchaudio

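# Restrict PyTorch to a single intra-op thread; for a small streaming model
# this avoids thread oversubscription (a common choice for CPU inference services).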
torch.set_num_threads(1)

from project_settings import project_path
from toolbox.torchaudio.models.dtln.configuration_dtln import DTLNConfig
from toolbox.torchaudio.models.dtln.modeling_dtln import DTLNPretrainedModel, MODEL_FILE

logger = logging.getLogger("toolbox")


class InferenceDTLN(object):
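    """
    Inference wrapper around a pretrained DTLN denoising model.

    Supports both offline (whole-utterance) enhancement and simulated
    streaming enhancement; see `denoise_offline` and `denoise_online`.
    """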
    def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
        self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
        self.device = torch.device(device)

        logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
        config, model = self.load_models(self.pretrained_model_path_or_zip_file)
        logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")

        self.config = config
        self.model = model
        self.model.to(device)
        self.model.eval()

    def load_models(self, model_path: str):
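        """
        Load config and weights from `model_path`.

        `model_path` may be a model directory or a `.zip` archive of one
        (the archive is assumed to contain a top-level folder named after
        the zip's stem). Archives are extracted to a temporary directory,
        which is removed again after loading.
        """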
        model_path = Path(model_path)
        extracted = False
        if model_path.name.endswith(".zip"):
            with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
                out_root = Path(tempfile.gettempdir()) / "nx_denoise"
                out_root.mkdir(parents=True, exist_ok=True)
                f_zip.extractall(path=out_root)
            model_path = out_root / model_path.stem
            extracted = True

        config = DTLNConfig.from_pretrained(
            pretrained_model_name_or_path=model_path.as_posix(),
        )
        model = DTLNPretrainedModel.from_pretrained(
            pretrained_model_name_or_path=model_path.as_posix(),
        )
        model.to(self.device)
        model.eval()

        # Only remove directories created by the zip extraction above;
        # never delete a user-supplied model directory.
        if extracted:
            shutil.rmtree(model_path)
        return config, model

    def enhancement_by_ndarray(self, noisy_audio: np.ndarray) -> np.ndarray:
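        """
        Denoise a 1-D float waveform with samples in [-1, 1] and return
        the enhanced waveform as a 1-D numpy array.
        """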
        noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
        noisy_audio = noisy_audio.unsqueeze(dim=0)

        # noisy_audio shape: [batch_size, n_samples]
        enhanced_audio = self.denoise_offline(noisy_audio)
        # enhanced_audio shape: [channels, num_samples]
        enhanced_audio = enhanced_audio[0]
        # enhanced_audio shape: [num_samples]
        return enhanced_audio.cpu().numpy()

    def denoise_offline(self, noisy_audio: torch.Tensor) -> torch.Tensor:
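        """
        Enhance the whole utterance in a single forward pass.

        :param noisy_audio: tensor of shape [batch_size, num_samples].
        :return: tensor of shape [channels, num_samples].
        """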
        if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
            raise AssertionError("The value range of audio samples should be between -1 and 1.")

        # noisy_audio shape: [batch_size, num_samples]
        noisy_audios = noisy_audio.to(self.device)

        with torch.no_grad():
            denoise = self.model.forward(noisy_audios)

        # denoise shape: [batch_size, 1, num_samples]
        denoise = denoise[0]
        # shape: [channels, num_samples]
        return denoise

    def denoise_online(self, noisy_audio: torch.Tensor) -> torch.Tensor:
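        """
        Simulate streaming enhancement: `forward_chunk_by_chunk` feeds the
        waveform to the model chunk by chunk, so the output approximates
        what a frame-wise online deployment would produce.

        :param noisy_audio: tensor of shape [batch_size, num_samples].
        :return: tensor of shape [channels, num_samples].
        """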
        if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
            raise AssertionError("The value range of audio samples should be between -1 and 1.")

        # noisy_audio shape: [batch_size, num_samples]
        noisy_audios = noisy_audio.to(self.device)

        with torch.no_grad():
            denoise = self.model.forward_chunk_by_chunk(noisy_audios)

        # denoise shape: [batch_size, 1, num_samples]
        denoise = denoise[0]
        # shape: [channels, num_samples]
        return denoise


def main():
    model_zip_file = project_path / "trained_models/dtln-nx-dns3.zip"
    infer_model = InferenceDTLN(model_zip_file.as_posix())

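    # Load the test clip at the model's sample rate (assumed to be 8 kHz for this checkpoint).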
    sample_rate = 8000
    noisy_audio_file = project_path / "data/examples/ai_agent/chinese-3.wav"
    noisy_audio, sample_rate = librosa.load(
        noisy_audio_file.as_posix(),
        sr=sample_rate,
    )
    duration = librosa.get_duration(y=noisy_audio, sr=sample_rate)
    # noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)]
    noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
    noisy_audio = noisy_audio.unsqueeze(dim=0)
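    # noisy_audio shape: [1, num_samples]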

    # offline
    begin = time.time()
    enhanced_audio = infer_model.denoise_offline(noisy_audio)
    time_cost = time.time() - begin
    print(f"enhanced_audio.shape: {enhanced_audio.shape}, time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, rtf: {time_cost / duration:.4f}")

    filename = "enhanced_audio_offline.wav"
    torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate)

    # online
    begin = time.time()
    enhanced_audio = infer_model.denoise_online(noisy_audio)
    time_cost = time.time() - begin
    print(f"enhanced_audio.shape: {enhanced_audio.shape}, time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, rtf: {time_cost / duration:.4f}")

    filename = "enhanced_audio_online.wav"
    torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate)

    return


if __name__ == "__main__":
    main()