File size: 3,456 Bytes
69fa971
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import math

import numpy as np


def freq2erb(freq_hz: float) -> float:
    """
    Map a frequency in Hz onto the ERB scale.

    Computes ``9.265 * ln(1 + freq_hz / (24.7 * 9.265))``.
    Reference: https://www.cnblogs.com/LXP-Never/p/16011229.html
    (note that 1 / (24.7 * 9.265) = 0.00436976).

    :param freq_hz: frequency in Hz.
    :return: the corresponding ERB-scale value (0.0 for 0 Hz).
    """
    q = 9.265
    corner_hz = 24.7 * q
    return q * math.log(freq_hz / corner_hz + 1)


def erb2freq(n_erb: float) -> float:
    """
    Map an ERB-scale value back to a frequency in Hz.

    Computes ``24.7 * 9.265 * (exp(n_erb / 9.265) - 1)``, the algebraic
    inverse of ``9.265 * ln(1 + f / (24.7 * 9.265))``.

    :param n_erb: value on the ERB scale.
    :return: frequency in Hz (0.0 for an ERB value of 0).
    """
    q = 9.265
    return 24.7 * q * (math.exp(n_erb / q) - 1)


def get_erb_widths(sample_rate: int, fft_size: int, erb_bins: int, min_freq_bins_for_erb: int) -> np.ndarray:
    """
    Compute how many FFT frequency bins each ERB band covers.

    Python port of ``erb_fb`` in
    https://github.com/Rikorose/DeepFilterNet/blob/main/libDF/src/lib.rs

    The ERB-scale range [freq2erb(0), freq2erb(sample_rate / 2)] is divided
    into ``erb_bins`` equal steps. The upper edge of each step is mapped back
    to Hz and rounded to the nearest FFT bin index; a band's width is the
    distance from the previous edge. A band narrower than
    ``min_freq_bins_for_erb`` is widened to that minimum, and the extra bins
    are subtracted from the following band(s). Finally the last band is
    adjusted so that the widths sum to ``fft_size // 2 + 1``, the number of
    bins in a one-sided (real-input) spectrum.

    :param sample_rate: sampling rate in Hz.
    :param fft_size: FFT length.
    :param erb_bins: number of ERB (Equivalent Rectangular Bandwidth) bands.
    :param min_freq_bins_for_erb: minimum number of frequency bins per ERB band.
    :return: uint64 array of shape [erb_bins] holding the width (in bins) of
        each band; its entries sum to ``fft_size // 2 + 1``.
    :raises ValueError: if ``erb_bins < 1``, or if the bins borrowed by widened
        bands would leave the last band with a negative width.
    """
    if erb_bins < 1:
        raise ValueError(f"erb_bins must be >= 1, got {erb_bins}")

    nyq_freq = sample_rate / 2.
    freq_width: float = sample_rate / fft_size
    # One-sided spectrum size; integer division keeps the bookkeeping integral
    # and is also correct for odd FFT sizes.
    num_freq_bins = fft_size // 2 + 1

    min_erb: float = freq2erb(0.)
    max_erb: float = freq2erb(nyq_freq)
    step = (max_erb - min_erb) / erb_bins

    erb = [0] * erb_bins
    prev_freq_bin = 0  # upper bin edge of the previous band
    freq_over = 0  # bins already handed to earlier bands by the minimum-width rule
    for i in range(1, erb_bins + 1):
        f = erb2freq(min_erb + i * step)
        # NOTE: Python's round() rounds exact halves to even, while Rust's
        # f32::round rounds them away from zero, so exact .5 ties may differ
        # from libDF. Kept as-is to preserve previously produced widths.
        freq_bin = int(round(f / freq_width))
        freq_bins = freq_bin - prev_freq_bin - freq_over

        if freq_bins < min_freq_bins_for_erb:
            # Enforce the minimum and remember how many bins were borrowed
            # so the next band is shrunk accordingly.
            freq_over = min_freq_bins_for_erb - freq_bins
            freq_bins = min_freq_bins_for_erb
        else:
            freq_over = 0
        erb[i - 1] = freq_bins
        prev_freq_bin = freq_bin

    # Band edges count bins up to index ~fft_size // 2; add 1 so the total can
    # cover all fft_size // 2 + 1 bins.
    erb[-1] += 1
    # Make the widths sum exactly to the one-sided bin count by adjusting the
    # last band (borrowed bins may have pushed the total over).
    erb[-1] -= sum(erb) - num_freq_bins
    if erb[-1] < 0:
        raise ValueError(
            f"min_freq_bins_for_erb={min_freq_bins_for_erb} is too large for "
            f"erb_bins={erb_bins} and fft_size={fft_size}: the last ERB band "
            f"would have {erb[-1]} bins"
        )
    return np.array(erb, dtype=np.uint64)


def get_erb_filter_bank(erb_widths: np.ndarray,
                        sample_rate: int,
                        normalized: bool = True,
                        inverse: bool = False,
                        ):
    """
    Build a rectangular ERB filter-bank matrix from per-band bin widths.

    Band ``k`` covers the ``erb_widths[k]`` consecutive frequency bins that
    follow the bins of bands ``0..k-1``; those entries are set to 1.

    :param erb_widths: 1-D integer array of band widths (in frequency bins).
    :param sample_rate: accepted for interface compatibility; not used here.
    :param normalized: forward bank: divide every column by its sum so each
        band averages its bins. Inverse bank: when True the 0/1 matrix is
        returned as-is; when False every row is divided by its sum.
    :param inverse: if True return the transposed matrix of shape
        [num_bands, num_freq_bins] instead of [num_freq_bins, num_bands].
    :return: float64 filter-bank matrix.
    """
    band_widths = erb_widths.tolist()
    matrix = np.zeros(shape=(int(np.sum(erb_widths)), len(erb_widths)))

    lo = 0
    for band, width in enumerate(band_widths):
        matrix[lo: lo + width, band] = 1
        lo += width

    if not inverse:
        if normalized:
            matrix /= np.sum(matrix, axis=0)
        return matrix

    matrix = matrix.T
    if not normalized:
        matrix /= np.sum(matrix, axis=1, keepdims=True)
    return matrix


def spec2erb(spec: np.ndarray, erb_fb: np.ndarray, db: bool = True, eps: float = 1e-10):
    """
    Project a spectrum onto an ERB filter bank, optionally in decibels.

    :param spec: (complex or real) spectrum whose LAST axis is frequency,
        e.g. shape [B, C, T, F].
    :param erb_fb: filter-bank MATRIX of shape [F, E] (frequency bins x ERB
        bands), e.g. the forward (``inverse=False``) output of
        ``get_erb_filter_bank``. Its first dimension must equal ``spec``'s
        last dimension, because it is applied with ``np.matmul``.
    :param db: whether to convert the output to decibels
        (``10 * log10(x + eps)``). Defaults to `True`.
    :param eps: small offset added before the log to avoid ``log10(0)``.
        Defaults to 1e-10, i.e. a floor of -100 dB.
    :return: float32 array with the frequency axis replaced by the ERB axis,
        e.g. shape [B, C, T, E].
    """
    # Power spectrum |X|^2 (magnitude squared).
    spec_ = np.abs(spec) ** 2

    # Sum/average frequency bins into ERB bands: [..., F] @ [F, E] -> [..., E].
    erb_feat = np.matmul(spec_, erb_fb)

    if db:
        erb_feat = 10 * np.log10(erb_feat + eps)

    erb_feat = np.array(erb_feat, dtype=np.float32)
    return erb_feat


def main():
    """Demo: build an ERB filter bank for 8 kHz audio (512-point FFT, 32 bands) and print its shape."""
    sr = 8000
    widths = get_erb_widths(
        sample_rate=sr,
        fft_size=512,
        erb_bins=32,
        min_freq_bins_for_erb=2,
    )
    bank = get_erb_filter_bank(erb_widths=widths, sample_rate=sr)
    print(bank.shape)

    return


# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()