File size: 3,971 Bytes
94ba8b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
Local (per-frame) SNR training target, following the DeepFilterNet implementation:
https://github.com/Rikorose/DeepFilterNet/blob/main/DeepFilterNet/df/modules.py#L816
"""
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch.nn import functional as F
import torchaudio


def local_energy(spec: torch.Tensor,
                 n_frame: int,
                 device: Optional[torch.device] = None,
                 ) -> torch.Tensor:
    """
    Compute a Hann-weighted moving average of per-frame energy.

    Squared values are summed over the last two dimensions (frequency bins and
    the real/imag pair). The result is then smoothed along time with a centred
    window of ``n_frame`` frames. The time axis is zero-padded at both ends, so
    its length is preserved.

    :param spec: real view of a complex spectrogram, shape [b, c, t, f, 2].
    :param n_frame: window length in frames. An even value is increased by one
        so the window has a centre frame. Must be non-negative.
    :param device: device for the window; defaults to ``spec.device``.
    :return: local energy, shape [b, c, t].
    :raises ValueError: if ``n_frame`` is negative.
    """
    if n_frame < 0:
        raise ValueError(f"n_frame must be non-negative, got {n_frame}")
    if n_frame % 2 == 0:
        n_frame += 1
    n_frame_half = n_frame // 2
    if device is None:
        device = spec.device

    # [b, c, t, f, 2] -> [b, c, t]: total energy of each frame.
    energy = spec.pow(2).sum(-1).sum(-1)
    # Zero-pad the time (last) axis so every frame has a full window around it.
    energy = F.pad(energy, (n_frame_half, n_frame_half, 0, 0))
    # energy shape: [b, c, t + 2 * n_frame_half]

    # NOTE: torch.hann_window defaults to periodic=True, so the window is not
    # symmetric about its centre (e.g. n_frame=3 gives [0, .75, .75]). This is
    # left unchanged to stay consistent with the DeepFilterNet code referenced
    # in the module docstring.
    weight = torch.hann_window(n_frame, device=device, dtype=energy.dtype)
    # weight shape: [n_frame]

    # Sliding windows of n_frame frames with stride 1: [b, c, t, n_frame].
    windows = energy.unfold(-1, size=n_frame, step=1) * weight

    # Normalized by the window length (not by the sum of the window weights).
    return windows.sum(dim=-1).div(n_frame)


def local_snr(spec_clean: torch.Tensor,
              spec_noise: torch.Tensor,
              n_frame: int = 5,
              db: bool = False,
              eps: float = 1e-12,
              ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute the local (per-frame) SNR between a clean and a noise spectrogram.

    :param spec_clean: clean spectrogram. Either complex with shape [b, c, t, f],
        or its real view with shape [b, c, t, f, 2].
    :param spec_noise: noise spectrogram, in the same layout as ``spec_clean``.
    :param n_frame: smoothing window length in frames, passed to
        ``local_energy``.
    :param db: if True, convert the ratio to decibels (10 * log10).
    :param eps: floor applied to the noise energy before dividing, and to the
        ratio before the log, to avoid division by zero and log(0).
    :return: tuple ``(snr, energy_clean, energy_noise)``, each of shape
        [b, c, t].
    :raises ValueError: if a real-valued input's last dimension is not 2.
    """
    # Complex inputs are viewed as real [..., 2] tensors. Real inputs are assumed
    # to already be in that layout and are used as-is.
    if spec_clean.is_complex():
        spec_clean = torch.view_as_real(spec_clean)
    if spec_noise.is_complex():
        spec_noise = torch.view_as_real(spec_noise)
    if spec_clean.shape[-1] != 2 or spec_noise.shape[-1] != 2:
        raise ValueError(
            "expected complex spectrograms or real views with a trailing "
            f"dimension of size 2, got shapes {tuple(spec_clean.shape)} "
            f"and {tuple(spec_noise.shape)}"
        )
    # [b, c, t, f, 2]

    energy_clean = local_energy(spec_clean, n_frame=n_frame, device=spec_clean.device)
    energy_noise = local_energy(spec_noise, n_frame=n_frame, device=spec_noise.device)
    # [b, c, t]

    snr = energy_clean / energy_noise.clamp_min(eps)
    # snr shape: [b, c, t]

    if db:
        snr = snr.clamp_min(eps).log10().mul(10)
    return snr, energy_clean, energy_noise


class LocalSnrTarget(nn.Module):
    """
    Per-frame local SNR target, clipped to ``[min_local_snr, max_local_snr]``.

    ``sample_rate``, ``nfft``, ``win_size`` and ``hop_size`` are stored on the
    module but are not read by ``forward``.
    """

    def __init__(self,
                 sample_rate: int = 8000,
                 nfft: int = 512,
                 win_size: int = 512,
                 hop_size: int = 256,
                 n_frame: int = 3,
                 min_local_snr: int = -15,
                 max_local_snr: int = 30,
                 db: bool = True,
                 ):
        super().__init__()
        # Smoothing window length in frames, forwarded to local_snr.
        self.n_frame = n_frame
        # Whether local_snr returns decibels (True) or a linear energy ratio.
        self.db = db
        # Clamp bounds for the output. They are in the same unit as the SNR, so
        # dB when db=True and a linear ratio otherwise.
        self.min_local_snr = min_local_snr
        self.max_local_snr = max_local_snr

        # STFT configuration, kept as attributes only.
        self.sample_rate = sample_rate
        self.nfft = nfft
        self.win_size = win_size
        self.hop_size = hop_size

    def forward(self,
                spec_clean: torch.Tensor,
                spec_noise: torch.Tensor,
                ) -> torch.Tensor:
        """
        Compute the clipped local SNR between clean and noise spectrograms.

        :param spec_clean: torch.complex, shape: [b, c, t, f]
        :param spec_noise: torch.complex, shape: [b, c, t, f]
        :return: lsnr, shape [b, t] when c == 1. For c > 1, ``squeeze(1)`` does
            nothing and the result stays [b, c, t].
        """
        snr_per_frame = local_snr(
            spec_clean=spec_clean,
            spec_noise=spec_noise,
            n_frame=self.n_frame,
            db=self.db,
        )[0]
        # snr_per_frame shape: [b, c, t]
        bounded = snr_per_frame.clamp(min=self.min_local_snr, max=self.max_local_snr)
        # Drop the channel axis: [b, c, t] -> [b, t].
        return bounded.squeeze(dim=1)


def main():
    """Run LocalSnrTarget on a random signal and print the tensor shapes."""
    sample_rate = 8000
    nfft = 512
    win_size = 512
    hop_size = 256
    window_fn = "hamming"

    transform = torchaudio.transforms.Spectrogram(
        n_fft=nfft,
        win_length=win_size,
        hop_length=hop_size,
        power=None,  # power=None keeps the complex STFT instead of a power spectrum
        window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
    )

    noisy = torch.randn(size=(1, 16000), dtype=torch.float32)

    # Spectrogram output is [b, f, t]; rearrange to [b, c=1, t, f].
    spec = transform(noisy).permute(0, 2, 1).unsqueeze(1)
    print(f"spec.shape: {spec.shape}, spec.dtype: {spec.dtype}")

    local = LocalSnrTarget(
        sample_rate=sample_rate,
        nfft=nfft,
        win_size=win_size,
        hop_size=hop_size,
        n_frame=5,
        min_local_snr=-15,
        max_local_snr=30,
        db=True,
    )
    # Clean and noise are the same tensor here, so this only exercises shapes.
    # Call the module itself (not .forward) so any registered hooks run.
    lsnr_target = local(spec, spec)
    print(f"lsnr_target.shape: {lsnr_target.shape}")


if __name__ == "__main__":
    main()