import torch
from gyraudio.audio_separation.architecture.model import SeparationModel
from gyraudio.audio_separation.architecture.building_block import ResConvolution
from typing import Optional
class EncoderSingleStage(torch.nn.Module):
"""
Extend channels
Resnet
Downsample by 2
"""
    def __init__(self, ch: int, ch_out: int, hdim: Optional[int] = None, k_size: int = 5):
# ch_out ~ ch_in*extension_factor
super().__init__()
hdim = hdim or ch
self.extension_conv = torch.nn.Conv1d(ch, ch_out, k_size, padding=k_size//2)
self.res_conv = ResConvolution(ch_out, hdim=hdim, k_size=k_size)
        # Warning: max pooling introduces a half-sample jitter/offset between scales!
self.max_pool = torch.nn.MaxPool1d(kernel_size=2)
def forward(self, x):
x = self.extension_conv(x)
x = self.res_conv(x)
x_ds = self.max_pool(x)
return x, x_ds
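# Minimal shape sanity-check for EncoderSingleStage (a hedged sketch: it assumes
# ResConvolution preserves the [channel, length] shape of its input). Numbers
# match the first stage of ResUNet below: 16 -> 24 channels, T=2048.
def _demo_encoder_shapes():
    enc = EncoderSingleStage(ch=16, ch_out=24)
    x_skip, x_ds = enc(torch.rand(2, 16, 2048))
    assert x_skip.shape == (2, 24, 2048)  # full resolution, extended channels
    assert x_ds.shape == (2, 24, 1024)    # downsampled by 2, fed to the next stage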
class DecoderSingleStage(torch.nn.Module):
"""
Upsample by 2
Resnet
Extend channels
"""
    def __init__(self, ch: int, ch_out: int, hdim: Optional[int] = None, k_size: int = 5):
"""Decoder stage
Args:
ch (int): channel size (downsampled & skip connection have same channel size)
ch_out (int): number of output channels (shall match the number of input channels of the next stage)
hdim (Optional[int], optional): Hidden dimension used in the residual block. Defaults to None.
k_size (int, optional): Convolution size. Defaults to 5.
Notes:
======
        ch_out = ch / extension_factor (the mixer maps the concatenated 2*ch channels down to ch_out)
        self.scale_mixers_conv
        - defines how the decoded lower scale (x_ds) is merged with the current encoded scale (x_skip)
        - could be a pointwise (1x1) convolution instead
"""
super().__init__()
hdim = hdim or ch
self.scale_mixers_conv = torch.nn.Conv1d(2*ch, ch_out, k_size, padding=k_size//2)
self.res_conv = ResConvolution(ch_out, hdim=hdim, k_size=k_size)
        # Warning: linear upsampling must be consistent ("conjugated") with the encoder's downsampling;
        # special care is required regarding sample offsets (illustrated in _demo_alignment_caveat below),
        # see https://arxiv.org/abs/1806.03185
self.upsample = torch.nn.Upsample(scale_factor=2, mode="linear", align_corners=True)
self.non_linearity = torch.nn.ReLU()
def forward(self, x_ds: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor:
""""""
x_us = self.upsample(x_ds) # [N, ch, T/2] -> [N, ch, T]
        x = torch.cat([x_us, x_skip], dim=1)  # [N, 2*ch, T]
x = self.scale_mixers_conv(x) # [N, ch_out, T]
x = self.non_linearity(x)
x = self.res_conv(x) # [N, ch_out, T]
return x
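# A mirror shape sketch for DecoderSingleStage, again assuming ResConvolution
# preserves its input shape: merging a [N, ch, T/2] decoded tensor with a
# [N, ch, T] skip connection yields [N, ch_out, T] (ch=24, ch_out=16 as in
# ResUNet's first level).
def _demo_decoder_shapes():
    dec = DecoderSingleStage(ch=24, ch_out=16)
    x = dec(torch.rand(2, 24, 1024), torch.rand(2, 24, 2048))
    assert x.shape == (2, 16, 2048)


# Hedged illustration of the offset warning in DecoderSingleStage.__init__:
# max pooling then linear re-upsampling is not the identity, even on a smooth
# ramp, because of the half-sample offset between scales
# (https://arxiv.org/abs/1806.03185).
def _demo_alignment_caveat():
    ramp = torch.arange(8, dtype=torch.float32)[None, None, :]  # [1, 1, 8]
    x_ds = torch.nn.MaxPool1d(kernel_size=2)(ramp)              # [1, 1, 4], max of each pair
    x_us = torch.nn.Upsample(scale_factor=2, mode="linear", align_corners=True)(x_ds)
    assert not torch.allclose(x_us, ramp)  # reconstruction is offset, not exact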
class ResUNet(SeparationModel):
"""Convolutional neural network for audio separation,
Decimation, bottleneck
"""
def __init__(self,
ch_in: int = 1,
ch_out: int = 2,
channels_extension: float = 1.5,
                 h_dim: int = 16,
                 k_size: int = 5,
) -> None:
super().__init__()
        self.need_split = ch_out != ch_in  # if ch_out differs, split the output into separate per-source tensors
self.ch_out = ch_out
self.source_modality_conv = torch.nn.Conv1d(ch_in, h_dim, k_size, padding=k_size//2)
self.encoder_list = torch.nn.ModuleList()
self.decoder_list = torch.nn.ModuleList()
self.non_linearity = torch.nn.ReLU()
h_dim_current = h_dim
        for _level in range(4):  # fixed depth: 4 encoder/decoder stages
h_dim_ds = int(h_dim_current*channels_extension)
self.encoder_list.append(EncoderSingleStage(h_dim_current, h_dim_ds, k_size=k_size))
self.decoder_list.append(DecoderSingleStage(h_dim_ds, h_dim_current, k_size=k_size))
h_dim_current = h_dim_ds
self.bottleneck = ResConvolution(h_dim_current, k_size=k_size)
self.target_modality_conv = torch.nn.Conv1d(h_dim, ch_out, 1) # conv1x1 channel mixer
def forward(self, x_in):
        # x_in: [N, 1, 2048]
x0 = self.source_modality_conv(x_in)
x0 = self.non_linearity(x0)
        # x0: [N, 16, 2048]
x1_skip, x1_ds = self.encoder_list[0](x0)
        # x1_skip: [N, 24, 2048]
        # x1_ds:   [N, 24, 1024]
x2_skip, x2_ds = self.encoder_list[1](x1_ds)
        # x2_skip: [N, 36, 1024]
        # x2_ds:   [N, 36, 512]
x3_skip, x3_ds = self.encoder_list[2](x2_ds)
        # x3_skip: [N, 54, 512]
        # x3_ds:   [N, 54, 256]
x4_skip, x4_ds = self.encoder_list[3](x3_ds)
        # x4_skip: [N, 81, 256]
        # x4_ds:   [N, 81, 128]
x4_dec = self.bottleneck(x4_ds)
x3_dec = self.decoder_list[3](x4_dec, x4_skip)
x2_dec = self.decoder_list[2](x3_dec, x3_skip)
x1_dec = self.decoder_list[1](x2_dec, x2_skip)
x0_dec = self.decoder_list[0](x1_dec, x1_skip)
demuxed = self.target_modality_conv(x0_dec)
        # no final non-linearity: separated waveforms can take negative values
if self.need_split:
return torch.chunk(demuxed, self.ch_out, dim=1)
return demuxed, None
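# End-to-end sketch: with the defaults (ch_in=1, ch_out=2), forward() returns a
# 2-tuple of [N, 1, T] tensors. Naming them "signal" and "noise" here is only an
# illustrative assumption about the two separated sources, not part of the API.
def _demo_resunet_shapes():
    model = ResUNet()
    signal, noise = model(torch.rand(2, 1, 2048))
    assert signal.shape == (2, 1, 2048)
    assert noise.shape == (2, 1, 2048)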
if __name__ == "__main__":
model = ResUNet()
    inp = torch.rand(2, 1, 2048)  # [N, ch_in, T]
out = model(inp)
print(model)
print(model.count_parameters())
    print(out[0].shape)  # first separated source: [N, 1, T]
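    # Run the hedged sanity-check sketches defined above.
    _demo_encoder_shapes()
    _demo_decoder_shapes()
    _demo_alignment_caveat()
    _demo_resunet_shapes()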