Spaces:
Building
Building
File size: 2,092 Bytes
f6b56a2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import torch
from gyraudio.audio_separation.architecture.model import SeparationModel
from typing import Tuple
class FlatConvolutional(SeparationModel):
    """Convolutional neural network for audio separation.

    No decimation, no bottleneck — a plain stack of same-resolution Conv1d
    layers (the "encoder" backbone) followed by a 1x1-conv demux head that
    splits the hidden features into ``ch_out`` separated signals.
    """

    def __init__(self,
                 ch_in: int = 1,
                 ch_out: int = 2,
                 h_dim: int = 16,
                 k_size: int = 5,
                 dilation: int = 1
                 ) -> None:
        """
        Args:
            ch_in (int): number of input channels (1 for mono audio).
            ch_out (int): number of separated signals produced by forward.
            h_dim (int): hidden feature dimension of the backbone.
            k_size (int): convolution kernel size (odd values keep T unchanged).
            dilation (int): dilation factor of the backbone convolutions.
        """
        super().__init__()
        # Remember how many signals to split out in forward()
        # (previously hard-coded to 2, ignoring ch_out).
        self.ch_out = ch_out
        # "Same" padding so the temporal length T is preserved.
        pad = dilation * (k_size // 2)

        def hidden_conv() -> torch.nn.Conv1d:
            # h_dim -> h_dim backbone convolution with same-padding.
            return torch.nn.Conv1d(h_dim, h_dim, k_size,
                                   dilation=dilation, padding=pad)

        self.conv1 = torch.nn.Conv1d(ch_in, h_dim, k_size,
                                     dilation=dilation, padding=pad)
        self.conv2 = hidden_conv()
        self.conv3 = hidden_conv()
        self.conv4 = hidden_conv()
        self.relu = torch.nn.ReLU()
        self.encoder = torch.nn.Sequential(
            self.conv1, self.relu,
            self.conv2, self.relu,
            self.conv3, self.relu,
            self.conv4, self.relu,
        )
        # 1x1-conv head mapping hidden features to ch_out output channels.
        self.demux = torch.nn.Sequential(
            torch.nn.Conv1d(h_dim, h_dim // 2, 1),  # conv1x1
            torch.nn.ReLU(),
            torch.nn.Conv1d(h_dim // 2, ch_out, 1),  # conv1x1
        )

    def forward(self, mixed_sig_in: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Separate the mixed input signal into ``ch_out`` components.

        Args:
            mixed_sig_in (torch.Tensor): mixed signal [N, C, T]

        Returns:
            Tuple[torch.Tensor, ...]: ``ch_out`` tensors of shape [N, 1, T]
            (with the default ``ch_out=2``: the two separated signals).
        """
        # Convolution backbone: [N, C, T] -> [N, h, T]
        features = self.encoder(mixed_sig_in)
        # Demux head: [N, h, T] -> [N, ch_out, T]
        demuxed = self.demux(features)
        # Split channel-wise into ch_out single-channel signals.
        return torch.chunk(demuxed, self.ch_out, dim=1)
|