balthou's picture
draft audio sep app
f6b56a2
import torch
from gyraudio.audio_separation.architecture.model import SeparationModel
from typing import Optional, Tuple
def get_non_linearity(activation: str):
    """Instantiate the activation module selected by *activation*.

    ``"LeakyReLU"`` yields ``torch.nn.LeakyReLU()``; any other value
    (including unknown names) falls back to ``torch.nn.ReLU()``.
    """
    use_leaky = activation == "LeakyReLU"
    return torch.nn.LeakyReLU() if use_leaky else torch.nn.ReLU()
class BaseConvolutionBlock(torch.nn.Module):
    """Conv1d -> non-linearity -> Dropout1d, keeping the temporal length.

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels.
        k_size: convolution kernel size. Odd sizes use symmetric ``k_size // 2``
            zero padding; even sizes use ``padding="same"`` so the output length
            still equals the input length.
        activation: ``"LeakyReLU"`` selects LeakyReLU, any other value ReLU
            (see ``get_non_linearity``).
        dropout: probability of zeroing whole channels (``Dropout1d``).
        bias: whether the convolution has a bias term.
    """

    def __init__(self, ch_in: int, ch_out: int, k_size: int, activation="LeakyReLU", dropout: float = 0, bias: bool = True) -> None:
        super().__init__()
        # Symmetric ``k_size // 2`` padding only preserves the length for odd
        # kernels; an even kernel would produce T + 1 samples and misalign the
        # U-Net skip concatenations. "same" pads asymmetrically in that case.
        padding = k_size // 2 if k_size % 2 == 1 else "same"
        self.conv = torch.nn.Conv1d(ch_in, ch_out, k_size, padding=padding, bias=bias)
        self.non_linearity = get_non_linearity(activation)
        self.dropout = torch.nn.Dropout1d(p=dropout)

    def forward(self, x_in: torch.Tensor) -> torch.Tensor:
        """Apply conv, activation and dropout: [N, ch_in, T] -> [N, ch_out, T]."""
        x = self.conv(x_in)  # [N, ch_in, T] -> [N, ch_out, T]
        x = self.non_linearity(x)
        x = self.dropout(x)
        return x
class EncoderStage(torch.nn.Module):
    """Encoder level: convolution (usually widening channels), then 2x decimation.

    Decimation keeps every other sample (``x[..., ::2]``); no anti-aliasing
    filter is applied.
    """

    def __init__(self, ch_in: int, ch_out: int, k_size: int = 15, dropout: float = 0, bias: bool = True) -> None:
        super().__init__()
        self.conv_block = BaseConvolutionBlock(ch_in, ch_out, k_size=k_size, dropout=dropout, bias=bias)

    def forward(self, x):
        """Return ``(features, decimated_features)``.

        ``features`` (ch_out channels, full resolution) serves as the skip
        connection; the decimated copy feeds the next, coarser level.
        """
        features = self.conv_block(x)
        decimated = features[..., ::2]
        return features, decimated
class DecoderStage(torch.nn.Module):
    """Upsample, concatenate with the skip connection, then conv (shrinking channels).

    The coarse input is linearly interpolated to the exact temporal length of
    the skip connection. When the skip is twice as long this is the usual 2x
    upsample; otherwise (skip length odd after ``[..., ::2]`` decimation) it
    still lines the two tensors up instead of failing in ``torch.cat``.
    """

    def __init__(self, ch_in: int, ch_out: int, k_size: int = 5, dropout: float = 0., bias: bool = True) -> None:
        """Decoder stage.

        Args:
            ch_in: channels after concatenation, i.e. channels of the upsampled
                input plus channels of the skip connection.
            ch_out: number of output channels.
            k_size: convolution kernel size.
            dropout: channel dropout probability of the conv block.
            bias: whether the convolution has a bias term.
        """
        super().__init__()
        self.conv_block = BaseConvolutionBlock(ch_in, ch_out, k_size=k_size, dropout=dropout, bias=bias)
        # Parameter-free; kept as the reference configuration (mode and
        # align_corners) for the interpolation performed in forward().
        self.upsample = torch.nn.Upsample(scale_factor=2, mode="linear", align_corners=True)

    def forward(self, x_ds: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor:
        """Upsample ``x_ds`` to ``x_skip``'s length, concatenate on channels, convolve.

        Args:
            x_ds: coarse features [N, C_ds, T_ds].
            x_skip: skip-connection features [N, C_skip, T].

        Returns:
            Tensor [N, ch_out, T].
        """
        # With align_corners=True the sampling grid depends only on the input and
        # output sizes, so for T == 2 * T_ds this equals self.upsample(x_ds).
        x_us = torch.nn.functional.interpolate(
            x_ds,
            size=x_skip.shape[-1],
            mode=self.upsample.mode,
            align_corners=self.upsample.align_corners,
        )  # [N, C_ds, T_ds] -> [N, C_ds, T]
        x = torch.cat([x_us, x_skip], dim=1)  # [N, C_ds + C_skip, T]
        x = self.conv_block(x)  # [N, ch_out, T]
        return x
class WaveUNet(SeparationModel):
    """UNET in temporal domain (waveform)
    = Multiscale convolutional neural network for audio separation
    https://arxiv.org/abs/1806.03185

    Structure (E = ``channels_extension``):
    - encoder level i (0-based) maps ``i*E`` channels (``ch_in`` for i=0) to
      ``(i+1)*E`` channels, then decimates time by 2;
    - a bottleneck conv maps ``num_layers*E`` to ``(num_layers+1)*E`` channels;
    - decoder levels mirror the encoders: upsample, concatenate the matching
      skip connection, shrink back to ``level*E`` channels;
    - the input waveform is concatenated to the last decoder output
      (``E + ch_in`` channels) and a 1x1 conv mixes it down to ``ch_out``.
    """

    def __init__(self,
                 ch_in: int = 1,
                 ch_out: int = 2,
                 channels_extension: int = 24,
                 k_conv_ds: int = 15,
                 k_conv_us: int = 5,
                 num_layers: int = 6,
                 dropout: float = 0.0,
                 bias: bool = True,
                 ) -> None:
        """Build encoders, decoders, bottleneck and the output channel mixer.

        Args:
            ch_in: number of input waveform channels.
            ch_out: number of output channels of the final 1x1 convolution.
            channels_extension: channels added at every encoder level (E).
            k_conv_ds: kernel size of the encoder and bottleneck convolutions.
            k_conv_us: kernel size of the decoder convolutions.
            num_layers: number of encoder (and decoder) levels; must be >= 1.
            dropout: channel dropout probability used in every block and before
                the output mixer.
            bias: whether the convolutions have bias terms.

        Raises:
            ValueError: if ``num_layers < 1`` (the bottleneck would expect 0
                input channels and ``forward`` could never run).
        """
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        # When output and input channel counts differ, forward() returns one
        # tensor per output channel instead of (output, None).
        self.need_split = ch_out != ch_in
        self.ch_out = ch_out
        self.encoder_list = torch.nn.ModuleList()
        self.decoder_list = torch.nn.ModuleList()
        # First encoder: raw input channels -> E channels.
        self.encoder_list.append(EncoderStage(ch_in, channels_extension, k_size=k_conv_ds, dropout=dropout, bias=bias))
        for level in range(1, num_layers + 1):
            ch_i = level * channels_extension
            ch_o = (level + 1) * channels_extension
            if level < num_layers:
                # Only num_layers - 1 encoders are built here: the first one was
                # created above the loop.
                self.encoder_list.append(EncoderStage(ch_i, ch_o, k_size=k_conv_ds, dropout=dropout, bias=bias))
            # Decoder input = upsampled deeper features (ch_o) + skip connection (ch_i).
            self.decoder_list.append(DecoderStage(ch_o + ch_i, ch_i, k_size=k_conv_us, dropout=dropout, bias=bias))
        self.bottleneck = BaseConvolutionBlock(
            num_layers * channels_extension,
            (num_layers + 1) * channels_extension,
            k_size=k_conv_ds,
            dropout=dropout,
            bias=bias)
        self.dropout = torch.nn.Dropout1d(p=dropout)
        self.target_modality_conv = torch.nn.Conv1d(
            channels_extension + ch_in, ch_out, 1, bias=bias)  # conv1x1 channel mixer

    def forward(self, x_in: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Forward UNET pass.

        Example shapes as (channels, time), batch dim omitted, for ch_in=1,
        channels_extension=24, num_layers=4, T=2048. Left column: encoder
        inputs; right column: decoder outputs.
        ```
        (1  , 2048)----------------->(24 , 2048) + x_in -> (25, 2048) -> (ch_out, 2048)
            v                             ^
        (24 , 1024)----------------->(48 , 1024)
            v                             ^
        (48 , 512 )----------------->(72 , 512 )
            v                             ^
        (72 , 256 )----------------->(96 , 256 )
            v                             ^
        (96 , 128 )----BOTTLENECK--->(120, 128 )
        ```
        Encoder outputs (x_skip, x_ds) for that configuration:
        (24, 2048), (24, 1024) | (48, 1024), (48, 512) |
        (72, 512), (72, 256)   | (96, 256), (96, 128).

        Args:
            x_in: input batch [N, ch_in, T].

        Returns:
            If ``ch_out != ch_in``: ``torch.chunk(demuxed, ch_out, dim=1)``, i.e.
            ``ch_out`` tensors of shape [N, 1, T]. Otherwise ``(demuxed, None)``
            with ``demuxed`` of shape [N, ch_out, T].
            NOTE(review): for ch_out=1 with ch_in>1 this is a 1-tuple, not a
            pair — confirm this is intended.
        """
        skipped_list = []
        # Only the current (coarsest) tensor is needed to feed the next encoder,
        # so there is no need to keep (or copy) every downsampled level.
        x_ds = x_in
        for enc in self.encoder_list:
            x_skip, x_ds = enc(x_ds)
            skipped_list.append(x_skip)
        x_dec = self.bottleneck(x_ds)
        # Decoders run deepest-first, each paired with the skip of its level.
        for dec, x_skip in zip(self.decoder_list[::-1], skipped_list[::-1]):
            x_dec = dec(x_dec, x_skip)
        x_dec = torch.cat([x_dec, x_in], dim=1)  # [N, E + ch_in, T]
        x_dec = self.dropout(x_dec)
        demuxed = self.target_modality_conv(x_dec)  # [N, ch_out, T]
        if self.need_split:
            return torch.chunk(demuxed, self.ch_out, dim=1)
        return demuxed, None
if __name__ == "__main__":
    # Smoke test: a 9-level WaveUNet with a single output channel applied to a
    # random batch of two mono 2048-sample signals.
    net = WaveUNet(ch_out=1, num_layers=9)
    batch = torch.rand(2, 1, 2048)
    prediction = net(batch)
    print(net)
    print(net.count_parameters())
    first_output = prediction[0]
    print(first_output.shape)