Spaces:
Building
Building
import torch | |
from gyraudio.audio_separation.architecture.model import SeparationModel | |
from gyraudio.audio_separation.architecture.building_block import ResConvolution | |
from typing import Optional | |
# import logging | |
class EncoderSingleStage(torch.nn.Module):
    """One encoder level of the residual U-Net.

    Pipeline: channel-extension conv -> residual block -> 2x max-pool decimation.
    The full-resolution features are returned together with the decimated ones
    so that the matching decoder stage can use them as a skip connection.
    """

    def __init__(self, ch: int, ch_out: int, hdim: Optional[int] = None, k_size=5):
        """
        Args:
            ch (int): number of input channels.
            ch_out (int): number of output channels (~ ch * extension factor).
            hdim (Optional[int], optional): hidden dimension of the residual block.
                Falls back to ``ch`` when falsy (None or 0). Defaults to None.
            k_size (int, optional): convolution kernel size; padding is
                ``k_size // 2`` so odd sizes keep the time length. Defaults to 5.
        """
        super().__init__()
        same_pad = k_size // 2
        self.extension_conv = torch.nn.Conv1d(ch, ch_out, k_size, padding=same_pad)
        self.res_conv = ResConvolution(ch_out, hdim=hdim if hdim else ch, k_size=k_size)
        # NOTE: MaxPool1d(2) floors odd lengths and keeps the max of each pair,
        # which can introduce a one-sample jitter/offset in the decimated signal.
        self.max_pool = torch.nn.MaxPool1d(kernel_size=2)

    def forward(self, x):
        """Return ``(skip, downsampled)``.

        ``skip`` is the output of extension conv + residual block; ``downsampled``
        is ``skip`` max-pooled by 2 along time (length floor(T / 2), assuming
        ResConvolution preserves the time length -- TODO confirm).
        """
        features = self.res_conv(self.extension_conv(x))
        return features, self.max_pool(features)
class DecoderSingleStage(torch.nn.Module):
    """One decoder level of the residual U-Net.

    Pipeline: linear 2x upsampling of the coarser decoded features ->
    concatenation with the encoder skip connection -> scale-mixing conv + ReLU
    -> residual block.
    """

    def __init__(self, ch: int, ch_out: int, hdim: Optional[int] = None, k_size=5):
        """Decoder stage

        Args:
            ch (int): channel size (downsampled & skip connection have same channel size)
            ch_out (int): number of output channels (shall match the number of input
                channels of the next stage)
            hdim (Optional[int], optional): Hidden dimension used in the residual block.
                Falls back to ``ch`` when falsy. Defaults to None.
            k_size (int, optional): Convolution size. Defaults to 5.

        Notes:
            - The scale mixer receives ``2 * ch`` channels (upsampled ``x_ds``
              concatenated with ``x_skip``) and maps them to ``ch_out``. In ResUNet,
              ``ch_out`` is the channel count of the next finer level, i.e. roughly
              ``ch / extension_factor``.
            - ``self.scale_mixers_conv`` tells how the lower decoded scale (x_ds) is
              merged with the current encoded scale (x_skip); it could be a pointwise
              (conv1x1) mixer.
        """
        super().__init__()
        hdim = hdim or ch
        self.scale_mixers_conv = torch.nn.Conv1d(2 * ch, ch_out, k_size, padding=k_size // 2)
        self.res_conv = ResConvolution(ch_out, hdim=hdim, k_size=k_size)
        # warning: Linear interpolation shall be "conjugated" with the skipping downsampling;
        # special care shall be taken regarding offsets.
        # https://arxiv.org/abs/1806.03185
        self.upsample = torch.nn.Upsample(scale_factor=2, mode="linear", align_corners=True)
        self.non_linearity = torch.nn.ReLU()

    def forward(self, x_ds: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor:
        """Merge the coarser decoded features with the skip connection.

        Args:
            x_ds (torch.Tensor): coarser decoded features, [N, ch, T // 2].
            x_skip (torch.Tensor): encoder features at the current scale, [N, ch, T].

        Returns:
            torch.Tensor: [N, ch_out, T] (assuming ResConvolution preserves the
            time length -- TODO confirm).
        """
        target_len = x_skip.shape[-1]
        x_us = self.upsample(x_ds)  # [N, ch, 2 * (T // 2)]
        if x_us.shape[-1] != target_len:
            # The encoder's MaxPool1d floors odd lengths, so a plain 2x upsampling
            # comes out one sample short and torch.cat would fail. Resample directly
            # to the skip length with the same interpolation settings instead.
            x_us = torch.nn.functional.interpolate(
                x_ds,
                size=target_len,
                mode=self.upsample.mode,
                align_corners=self.upsample.align_corners,
            )
        x = torch.cat([x_us, x_skip], dim=1)  # [N, 2.ch, T]
        x = self.scale_mixers_conv(x)  # [N, ch_out, T]
        x = self.non_linearity(x)
        x = self.res_conv(x)  # [N, ch_out, T]
        return x
class ResUNet(SeparationModel):
    """Residual 1D U-Net for audio source separation.

    A source-modality conv lifts the input to ``h_dim`` channels. ``n_levels``
    encoder stages each multiply the channel count by ``channels_extension`` and
    decimate time by 2. A residual bottleneck processes the coarsest scale, and
    mirrored decoder stages upsample and merge the encoder skip connections back
    to full resolution. A final 1x1 conv mixes channels into ``ch_out`` outputs.
    """

    def __init__(
        self,
        ch_in: int = 1,
        ch_out: int = 2,
        channels_extension: float = 1.5,
        h_dim=16,
        k_size=5,
        n_levels: int = 4,
    ) -> None:
        """
        Args:
            ch_in (int): number of input channels. Defaults to 1.
            ch_out (int): number of output channels. When it differs from ``ch_in``,
                ``forward`` splits the output into ``ch_out`` single-channel tensors.
                Defaults to 2.
            channels_extension (float): channel multiplication factor per encoder
                level (channel counts are truncated with ``int``). Defaults to 1.5.
            h_dim (int): channels produced by the source-modality conv. Defaults to 16.
            k_size (int): convolution kernel size. Defaults to 5.
            n_levels (int): number of encoder/decoder levels; time is decimated by
                ``2 ** n_levels`` at the bottleneck. Defaults to 4.

        Raises:
            ValueError: if ``n_levels`` is negative.
        """
        if n_levels < 0:
            raise ValueError(f"n_levels must be >= 0, got {n_levels}")
        super().__init__()
        self.need_split = ch_out != ch_in
        self.ch_out = ch_out
        self.source_modality_conv = torch.nn.Conv1d(ch_in, h_dim, k_size, padding=k_size // 2)
        self.encoder_list = torch.nn.ModuleList()
        self.decoder_list = torch.nn.ModuleList()
        self.non_linearity = torch.nn.ReLU()
        h_dim_current = h_dim
        for _level in range(n_levels):
            h_dim_ds = int(h_dim_current * channels_extension)
            self.encoder_list.append(EncoderSingleStage(h_dim_current, h_dim_ds, k_size=k_size))
            # decoder_list[i] maps level i+1 channels back to level i channels
            self.decoder_list.append(DecoderSingleStage(h_dim_ds, h_dim_current, k_size=k_size))
            h_dim_current = h_dim_ds
        self.bottleneck = ResConvolution(h_dim_current, k_size=k_size)
        self.target_modality_conv = torch.nn.Conv1d(h_dim, ch_out, 1)  # conv1x1 channel mixer

    def forward(self, x_in):
        """Separate ``x_in`` into ``ch_out`` signals.

        Args:
            x_in (torch.Tensor): input signal, [N, ch_in, T] (e.g. [2, 1, 2048]).

        Returns:
            If ``ch_out != ch_in``: the tuple produced by
            ``torch.chunk(prediction, ch_out, dim=1)``, i.e. ``ch_out`` tensors with
            one channel each. Otherwise ``(prediction, None)``.
        """
        x = self.non_linearity(self.source_modality_conv(x_in))  # [N, h_dim, T]
        skips = []
        for encoder in self.encoder_list:
            # skip keeps the current resolution, x is decimated by 2 for the next level
            skip, x = encoder(x)
            skips.append(skip)
        x = self.bottleneck(x)
        # Decoders run from the coarsest level to the finest, each paired with the
        # skip connection produced by the encoder of the same level.
        for decoder, skip in zip(reversed(self.decoder_list), reversed(skips)):
            x = decoder(x, skip)
        demuxed = self.target_modality_conv(x)
        # no relu on the output
        if self.need_split:
            return torch.chunk(demuxed, self.ch_out, dim=1)
        return demuxed, None
if __name__ == "__main__":
    # Smoke test: build the default model, run a random batch through it and
    # report the architecture, parameter count and first output shape.
    net = ResUNet()
    signal = torch.rand(2, 1, 2048)
    prediction = net(signal)
    print(net)
    print(net.count_parameters())
    print(prediction[0].shape)