File size: 2,092 Bytes
f6b56a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
from gyraudio.audio_separation.architecture.model import SeparationModel
from typing import Tuple


class FlatConvolutional(SeparationModel):
    """Flat 1D convolutional network for audio separation.

    No decimation and no bottleneck along the time axis: four stride-1
    (optionally dilated) convolutions, each followed by a ReLU, extract
    ``h_dim`` features per time step. Two 1x1 convolutions then map these
    features to ``ch_out`` channels, which ``forward`` splits in two.
    """

    def __init__(self,
                 ch_in: int = 1,
                 ch_out: int = 2,
                 h_dim: int = 16,
                 k_size: int = 5,
                 dilation: int = 1
                 ) -> None:
        """Build the convolutional encoder and the 1x1 demux head.

        Args:
            ch_in (int): number of channels of the mixed input signal.
            ch_out (int): number of channels produced by the demux head.
                ``forward`` splits them in two chunks, so the default of 2
                yields one channel per returned signal.
            h_dim (int): number of hidden feature channels (at least 2,
                since the demux head uses ``h_dim // 2`` channels).
            k_size (int): kernel size of the four encoder convolutions.
                An odd value keeps the temporal length unchanged; with an
                even value each encoder layer lengthens the signal by
                ``dilation`` samples.
            dilation (int): dilation of the four encoder convolutions.

        Raises:
            ValueError: if ``h_dim`` is smaller than 2.
        """
        if h_dim < 2:
            # h_dim // 2 would be 0 and the demux head would have an empty
            # (zero-channel) intermediate layer.
            raise ValueError(f"h_dim must be >= 2, got {h_dim}")
        super().__init__()

        # Symmetric zero padding compensating the receptive field of a
        # stride-1 dilated convolution (exact for odd kernel sizes).
        padding = dilation * (k_size // 2)

        def same_conv(in_channels: int) -> torch.nn.Conv1d:
            """Create one encoder convolution with ``h_dim`` output channels."""
            return torch.nn.Conv1d(
                in_channels, h_dim, k_size,
                dilation=dilation, padding=padding)

        # The layers are kept as named attributes in addition to being
        # entries of ``self.encoder`` so that existing attribute names and
        # state-dict keys remain unchanged.
        self.conv1 = same_conv(ch_in)
        self.conv2 = same_conv(h_dim)
        self.conv3 = same_conv(h_dim)
        self.conv4 = same_conv(h_dim)
        self.relu = torch.nn.ReLU()
        self.encoder = torch.nn.Sequential(
            self.conv1,
            self.relu,
            self.conv2,
            self.relu,
            self.conv3,
            self.relu,
            self.conv4,
            self.relu
        )
        # Per-time-step mixing of the features (kernel size 1).
        self.demux = torch.nn.Sequential(
            torch.nn.Conv1d(h_dim, h_dim//2, 1),  # conv1x1
            torch.nn.ReLU(),
            torch.nn.Conv1d(h_dim//2, ch_out, 1),  # conv1x1
        )

    def forward(self, mixed_sig_in: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract features from the mixed signal and split it in two signals.

        Args:
            mixed_sig_in (torch.Tensor): mixed input signal [N, C, T]

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: the output of the demux head
            split in two chunks along the channel axis. With the default
            ``ch_out=2`` these are two tensors of shape [N, 1, T].
        """
        # Convolution backbone
        # [N, C, T] -> [N, h, T]
        features = self.encoder(mixed_sig_in)
        # [N, h, T] -> [N, ch_out, T]
        demuxed = self.demux(features)
        # Split along channels: [N, 1, T], [N, 1, T] when ch_out == 2
        return torch.chunk(demuxed, 2, dim=1)