Spaces:
Building
Building
import torch | |
from gyraudio.audio_separation.architecture.model import SeparationModel | |
from typing import Tuple | |
class FlatConvolutional(SeparationModel): | |
"""Convolutional neural network for audio separation, | |
No decimation, no bottleneck, just basic signal processing | |
""" | |
def __init__(self, | |
ch_in: int = 1, | |
ch_out: int = 2, | |
h_dim=16, | |
k_size=5, | |
dilation=1 | |
) -> None: | |
super().__init__() | |
self.conv1 = torch.nn.Conv1d( | |
ch_in, h_dim, k_size, | |
dilation=dilation, padding=dilation*(k_size//2)) | |
self.conv2 = torch.nn.Conv1d( | |
h_dim, h_dim, k_size, | |
dilation=dilation, padding=dilation*(k_size//2)) | |
self.conv3 = torch.nn.Conv1d( | |
h_dim, h_dim, k_size, | |
dilation=dilation, padding=dilation*(k_size//2)) | |
self.conv4 = torch.nn.Conv1d( | |
h_dim, h_dim, k_size, | |
dilation=dilation, padding=dilation*(k_size//2)) | |
self.relu = torch.nn.ReLU() | |
self.encoder = torch.nn.Sequential( | |
self.conv1, | |
self.relu, | |
self.conv2, | |
self.relu, | |
self.conv3, | |
self.relu, | |
self.conv4, | |
self.relu | |
) | |
self.demux = torch.nn.Sequential(*( | |
torch.nn.Conv1d(h_dim, h_dim//2, 1), # conv1x1 | |
torch.nn.ReLU(), | |
torch.nn.Conv1d(h_dim//2, ch_out, 1), # conv1x1 | |
)) | |
def forward(self, mixed_sig_in: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | |
"""Perform feature extraction followed by classifier head | |
Args: | |
sig_in (torch.Tensor): [N, C, T] | |
Returns: | |
torch.Tensor: logits (not probabilities) [N, n_classes] | |
""" | |
# Convolution backbone | |
# [N, C, T] -> [N, h, T] | |
features = self.encoder(mixed_sig_in) | |
# [N, h, T] -> [N, 2, T] | |
demuxed = self.demux(features) | |
return torch.chunk(demuxed, 2, dim=1) # [N, 1, T], [N, 1, T] | |