balthou's picture
draft audio sep app
f6b56a2
import torch
from gyraudio.audio_separation.architecture.model import SeparationModel
from typing import Tuple
class FlatConvolutional(SeparationModel):
"""Convolutional neural network for audio separation,
No decimation, no bottleneck, just basic signal processing
"""
def __init__(self,
ch_in: int = 1,
ch_out: int = 2,
h_dim=16,
k_size=5,
dilation=1
) -> None:
super().__init__()
self.conv1 = torch.nn.Conv1d(
ch_in, h_dim, k_size,
dilation=dilation, padding=dilation*(k_size//2))
self.conv2 = torch.nn.Conv1d(
h_dim, h_dim, k_size,
dilation=dilation, padding=dilation*(k_size//2))
self.conv3 = torch.nn.Conv1d(
h_dim, h_dim, k_size,
dilation=dilation, padding=dilation*(k_size//2))
self.conv4 = torch.nn.Conv1d(
h_dim, h_dim, k_size,
dilation=dilation, padding=dilation*(k_size//2))
self.relu = torch.nn.ReLU()
self.encoder = torch.nn.Sequential(
self.conv1,
self.relu,
self.conv2,
self.relu,
self.conv3,
self.relu,
self.conv4,
self.relu
)
self.demux = torch.nn.Sequential(*(
torch.nn.Conv1d(h_dim, h_dim//2, 1), # conv1x1
torch.nn.ReLU(),
torch.nn.Conv1d(h_dim//2, ch_out, 1), # conv1x1
))
def forward(self, mixed_sig_in: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Perform feature extraction followed by classifier head
Args:
sig_in (torch.Tensor): [N, C, T]
Returns:
torch.Tensor: logits (not probabilities) [N, n_classes]
"""
# Convolution backbone
# [N, C, T] -> [N, h, T]
features = self.encoder(mixed_sig_in)
# [N, h, T] -> [N, 2, T]
demuxed = self.demux(features)
return torch.chunk(demuxed, 2, dim=1) # [N, 1, T], [N, 1, T]