File size: 1,587 Bytes
f6b56a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from typing import List, Optional, Sequence

import torch


class FilterBank(torch.nn.Module):
    """Convolution filter bank (linear).

    Serves as an embedding for the audio signal: one ``Conv1d`` branch per
    dilation, all applied to the same input, with the branch outputs
    concatenated along the channel axis (dim 1). No non-linearity is applied.

    Args:
        ch_in: number of input channels.
        out_dim: total number of output channels. It is split evenly across
            the dilation branches, so it must be a positive multiple of
            ``len(dilation_list)``.
        k_size: kernel size of every branch. With an odd ``k_size`` every
            branch preserves the time length. With an even ``k_size`` a branch
            with dilation ``d`` outputs ``d`` extra samples, so branches with
            different dilations cannot be concatenated.
        dilation_list: dilation of each branch, in branch order.

    Raises:
        ValueError: if ``dilation_list`` is empty or ``out_dim`` is not a
            positive multiple of ``len(dilation_list)``.
    """

    def __init__(self, ch_in: int, out_dim: int = 16, k_size: int = 5,
                 dilation_list: Sequence[int] = (1, 2, 4, 8)):
        super().__init__()
        # Snapshot the dilations: avoids a shared mutable default and also
        # accepts one-shot iterables such as generators.
        dilations = tuple(dilation_list)
        if not dilations:
            raise ValueError("dilation_list must contain at least one dilation")
        branch_dim, remainder = divmod(out_dim, len(dilations))
        if branch_dim < 1 or remainder:
            # Previously this only surfaced as an AssertionError in forward().
            raise ValueError(
                f"out_dim ({out_dim}) must be a positive multiple of "
                f"len(dilation_list) ({len(dilations)})"
            )
        self.out_dim = out_dim
        # padding = dilation * (k_size // 2) keeps the time length unchanged
        # for odd k_size, so all branch outputs line up for concatenation.
        self.source_modality_conv = torch.nn.ModuleList(
            torch.nn.Conv1d(ch_in, branch_dim, k_size, dilation=dilation,
                            padding=dilation * (k_size // 2))
            for dilation in dilations
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every branch to ``x`` and concatenate the results on dim 1.

        Args:
            x: input expected as ``(batch, ch_in, time)``, as for
                ``torch.nn.Conv1d``.

        Returns:
            Tensor whose dim 1 has size ``out_dim``.
        """
        out = torch.cat([conv(x) for conv in self.source_modality_conv], dim=1)
        # Internal invariant; guaranteed by the divisibility check in __init__.
        assert out.shape[1] == self.out_dim
        return out


class ResConvolution(torch.nn.Module):
    """ResNet building block with two length-preserving ``Conv1d`` layers.

    Computes ``relu(conv2(relu(conv1(x))) + x)``.
    https://paperswithcode.com/method/residual-connection

    Args:
        ch: number of input and output channels; they must match so the
            skip connection can be added.
        hdim: number of hidden channels between the two convolutions.
            A falsy value (``None`` or ``0``) means ``ch``.
        k_size: kernel size of both convolutions. Must be odd so that the
            ``k_size // 2`` padding preserves the time length and the
            residual addition lines up.

    Raises:
        ValueError: if ``k_size`` is even.
    """

    def __init__(self, ch: int, hdim: Optional[int] = None, k_size: int = 5):
        super().__init__()
        if k_size % 2 == 0:
            # An even kernel with k_size // 2 padding grows the output by one
            # sample per conv, which breaks the residual add in forward().
            raise ValueError(f"k_size must be odd to preserve the time length, got {k_size}")
        hdim = hdim or ch
        padding = k_size // 2
        self.conv1 = torch.nn.Conv1d(ch, hdim, k_size, padding=padding)
        self.conv2 = torch.nn.Conv1d(hdim, ch, k_size, padding=padding)
        self.non_linearity = torch.nn.ReLU()

    def forward(self, x_in: torch.Tensor) -> torch.Tensor:
        """Apply conv -> ReLU -> conv, add ``x_in``, then apply ReLU.

        Args:
            x_in: input expected as ``(batch, ch, time)``, as for
                ``torch.nn.Conv1d``.

        Returns:
            Tensor with the same shape as ``x_in``.
        """
        x = self.conv1(x_in)
        x = self.non_linearity(x)
        x = self.conv2(x)
        x += x_in  # skip connection
        x = self.non_linearity(x)
        return x


if __name__ == "__main__":
    # Smoke test: build a mono-input, 16-channel filter bank, push a random
    # batch of two 2048-sample signals through it, and report the module and
    # the shape of the first example's embedding.
    filter_bank = FilterBank(1, 16)
    signal = torch.rand(2, 1, 2048)
    embedding = filter_bank(signal)
    print(filter_bank)
    first_example = embedding[0]
    print(first_example.shape)