File size: 4,385 Bytes
75c6e9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
'''
@File    :   subband_util.py
@Contact :   [email protected]
@License :   (C)Copyright 2020-2021
@Modify Time      @Author    @Version    @Description
------------      -------    --------    -----------
2020/4/3 4:54 PM   Haohe Liu      1.0         None
'''

import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import os.path as op
from scipy.io import loadmat


def load_mat2numpy(fname=""):
    '''
    Read a MATLAB ``.mat`` file via ``scipy.io.loadmat``.

    Args:
        fname: path to the .mat file; a zero-length value means "nothing to load"
    Returns: the object produced by ``loadmat(fname)``, or None when ``fname``
        has length zero
    '''
    nothing_to_load = len(fname) == 0
    return None if nothing_to_load else loadmat(fname)


class PQMF(nn.Module):
    '''
    Subband analysis / synthesis filter bank implemented with two fixed
    (non-trainable) ``nn.Conv1d`` layers whose weights are read from .mat files.

    ``analysis`` splits every input channel into ``N`` subband channels
    (strided convolution, stride ``N``); ``synthesis`` merges each group of
    ``N`` subband channels back into one full-rate channel.

    Args:
        N: number of subbands
        M: filter length (number of taps)
        project_root: directory containing ``f_<N>_<M>.mat`` (variable ``f``,
            analysis filters) and ``h_<N>_<M>.mat`` (variable ``h``,
            synthesis filters)
    '''

    def __init__(self, N, M, project_root):
        super().__init__()
        self.N = N  # nsubband
        self.M = M  # nfilter
        # Explicit check instead of ``assert`` inside a bare ``except``:
        # asserts are stripped under ``python -O`` (the warning would then
        # never be printed) and a bare except can hide unrelated errors.
        if (N, M) not in [(8, 64), (4, 64), (2, 64)]:
            print("Warning:", N, "subbandand ", M, " filter is not supported")
        # Zeros appended to the signal tail in analysis() and trimmed again
        # at the end of synthesis().
        self.pad_samples = 64
        self.name = str(N) + "_" + str(M) + ".mat"

        # ---- analysis: one input channel -> N subband channels ----
        self.ana_conv_filter = nn.Conv1d(
            1, out_channels=N, kernel_size=M, stride=N, bias=False
        )
        data = load_mat2numpy(op.join(project_root, "f_" + self.name))
        data = data['f'].astype(np.float32) / N
        # Reverses the second axis of a 2-D array.
        # NOTE(review): presumably 'f' is stored as (N, M), so this flips the
        # taps of each filter -- confirm against the .mat layout.
        data = np.flipud(data.T).T
        data = np.reshape(data, (N, 1, M)).copy()
        dict_new = self.ana_conv_filter.state_dict().copy()
        dict_new['weight'] = torch.from_numpy(data)
        # Left-pad M - N zeros before the strided analysis convolution.
        self.ana_pad = nn.ConstantPad1d((M - N, 0), 0)
        self.ana_conv_filter.load_state_dict(dict_new)

        # ---- synthesis: N subband channels -> N interleaved phases ----
        # Right-pad so the stride-1 convolution keeps the subband length.
        self.syn_pad = nn.ConstantPad1d((0, M // N - 1), 0)
        self.syn_conv_filter = nn.Conv1d(
            N, out_channels=N, kernel_size=M // N, stride=1, bias=False
        )
        gk = load_mat2numpy(op.join(project_root, "h_" + self.name))
        gk = gk['h'].astype(np.float32)
        # Rearrange the flat filter table into a Conv1d weight of shape
        # (N, N, M // N), reversing the tap axis on the way.
        gk = np.transpose(np.reshape(gk, (N, M // N, N)), (1, 0, 2)) * N
        gk = np.transpose(gk[::-1, :, :], (2, 1, 0)).copy()
        dict_new = self.syn_conv_filter.state_dict().copy()
        dict_new['weight'] = torch.from_numpy(gk)
        self.syn_conv_filter.load_state_dict(dict_new)

        # The filters are fixed; exclude them from gradient updates.
        for param in self.parameters():
            param.requires_grad = False

    def __analysis_channel(self, inputs):
        '''
        Apply the analysis padding and filter to a single-channel signal.

        :param inputs: [batchsize,1,samples]
        :return: [batchsize,self.N,subband_samples]
        '''
        return self.ana_conv_filter(self.ana_pad(inputs))

    def __systhesis_channel(self, inputs):
        '''
        Merge one group of self.N subband channels into a single channel.

        :param inputs: [batchsize,self.N,subband_samples]
        :return: [batchsize,1,samples]
        '''
        ret = self.syn_conv_filter(self.syn_pad(inputs)).permute(0, 2, 1)
        # After the permute the N filter outputs for each time step are
        # adjacent, so flattening interleaves them into one sequence.
        return torch.reshape(ret, (ret.shape[0], 1, -1))

    def analysis(self, inputs):
        '''
        Split every channel of a waveform batch into self.N subbands.

        :param inputs: [batchsize,channel,raw_wav],value:[0,1]
        :return: [batchsize,channel*self.N,raw_wav_sub]; the subbands of input
            channel ``c`` occupy output channels ``c*N .. c*N+N-1``.
            Returns None when ``inputs`` has zero channels.
        '''
        inputs = F.pad(inputs, (0, self.pad_samples))
        # Collect per-channel results and concatenate once instead of growing
        # the output with torch.cat on every iteration.
        subbands = [
            self.__analysis_channel(inputs[:, i : i + 1, :])
            for i in range(inputs.size(1))
        ]
        if not subbands:
            return None
        return torch.cat(subbands, dim=1)

    def synthesis(self, data):
        '''
        Rebuild full-rate channels from subband channels produced by analysis().

        :param data: [batchsize,self.N*K,raw_wav_sub],value:[0,1]
        :return: [batchsize,K,raw_wav] with the self.pad_samples padding
            samples removed from the end
        :raises ValueError: if ``data`` has no channels
        '''
        # Each consecutive group of self.N channels belongs to one output channel.
        channels = [
            self.__systhesis_channel(data[:, i : i + self.N, :])
            for i in range(0, data.size(1), self.N)
        ]
        if not channels:
            raise ValueError("synthesis() requires at least one subband channel")
        ret = torch.cat(channels, dim=1)
        # Drop the tail samples that analysis() appended as padding.
        return ret[..., : -self.pad_samples]

    def forward(self, inputs):
        '''
        Run the raw analysis filter on a single-channel signal (no tail padding).

        :param inputs: [batchsize,1,samples]
        :return: [batchsize,self.N,subband_samples]
        '''
        return self.__analysis_channel(inputs)


if __name__ == "__main__":
    # Manual round-trip demo: analyse a random batch, synthesise it back,
    # then plot and print how far the reconstruction is from the input.
    import torch
    import numpy as np
    import matplotlib.pyplot as plt
    from tools.file.wav import *

    filter_bank = PQMF(N=4, M=64, project_root="/Users/admin/Documents/projects")

    # Fixed seed so the demo signal is reproducible.
    rng = np.random.RandomState(0)
    waveform = torch.tensor(rng.rand(4, 2, 32000), dtype=torch.float32)

    subbands = filter_bank.analysis(waveform)
    restored = filter_bank.synthesis(subbands)

    print(restored.size(), waveform.size())

    # Only the last 500 samples of the first item / first channel are plotted.
    tail = slice(-500, None)
    original_tail = waveform[0, 0, tail]
    restored_tail = restored[0, 0, tail]

    plt.subplot(211)
    plt.plot(original_tail)
    plt.subplot(212)
    plt.plot(restored_tail)
    plt.plot(original_tail - restored_tail)
    plt.show()

    # Total absolute reconstruction error over the whole batch.
    print((waveform - restored).abs().sum())