import torch
from gyraudio.audio_separation.architecture.model import SeparationModel
from typing import Optional, Tuple


def get_non_linearity(activation: str):
    if activation == "LeakyReLU":
        non_linearity = torch.nn.LeakyReLU()
    else:
        non_linearity = torch.nn.ReLU()
    return non_linearity


class BaseConvolutionBlock(torch.nn.Module):
    def __init__(self, ch_in: int, ch_out: int, k_size: int, activation="LeakyReLU",
                 dropout: float = 0., bias: bool = True) -> None:
        super().__init__()
        self.conv = torch.nn.Conv1d(ch_in, ch_out, k_size, padding=k_size//2, bias=bias)
        self.non_linearity = get_non_linearity(activation)
        self.dropout = torch.nn.Dropout1d(p=dropout)

    def forward(self, x_in: torch.Tensor) -> torch.Tensor:
        x = self.conv(x_in)  # [N, ch_in, T] -> [N, ch_out, T]
        x = self.non_linearity(x)
        x = self.dropout(x)
        return x


class EncoderStage(torch.nn.Module):
    """Conv (extend channels), then downsample by 2 by skipping every other sample."""

    def __init__(self, ch_in: int, ch_out: int, k_size: int = 15, dropout: float = 0., bias: bool = True) -> None:
        super().__init__()
        self.conv_block = BaseConvolutionBlock(ch_in, ch_out, k_size=k_size, dropout=dropout, bias=bias)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        x = self.conv_block(x)  # in WaveUNet, ch_out = ch_in + channels_extension
        x_ds = x[..., ::2]      # decimate: [N, ch_out, T] -> [N, ch_out, T/2]
        return x, x_ds


class DecoderStage(torch.nn.Module):
    """Upsample by 2, concatenate with the skip connection, then conv (shrink channels)."""

    def __init__(self, ch_in: int, ch_out: int, k_size: int = 5, dropout: float = 0., bias: bool = True) -> None:
        super().__init__()
        self.conv_block = BaseConvolutionBlock(ch_in, ch_out, k_size=k_size, dropout=dropout, bias=bias)
        self.upsample = torch.nn.Upsample(scale_factor=2, mode="linear", align_corners=True)

    def forward(self, x_ds: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor:
        x_us = self.upsample(x_ds)            # [N, ch, T/2] -> [N, ch, T]
        x = torch.cat([x_us, x_skip], dim=1)  # [N, ch_in, T] (upsampled + skip channels)
        x = self.conv_block(x)                # [N, ch_out, T]
        return x
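# Illustrative sanity check (not part of the original module): a minimal sketch
# showing that EncoderStage halves the temporal axis while DecoderStage restores
# it after fusing the skip connection. The helper name `_check_stage_shapes` and
# the channel sizes (24/48, mirroring the top level of WaveUNet below) are ours.
def _check_stage_shapes() -> None:
    enc = EncoderStage(ch_in=1, ch_out=24)
    x_skip, x_ds = enc(torch.rand(2, 1, 2048))
    assert x_skip.shape == (2, 24, 2048)  # full-resolution skip connection
    assert x_ds.shape == (2, 24, 1024)    # decimated by 2 for the next stage

    # Top-level decoder: 48 upsampled channels + 24 skip channels in, 24 out.
    dec = DecoderStage(ch_in=48 + 24, ch_out=24)
    x_dec = dec(torch.rand(2, 48, 1024), x_skip)
    assert x_dec.shape == (2, 24, 2048)   # back to full temporal resolution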
class WaveUNet(SeparationModel):
    """U-Net in the temporal domain (waveform): a multiscale convolutional
    neural network for audio source separation.
    https://arxiv.org/abs/1806.03185
    """

    def __init__(self,
                 ch_in: int = 1,
                 ch_out: int = 2,
                 channels_extension: int = 24,
                 k_conv_ds: int = 15,
                 k_conv_us: int = 5,
                 num_layers: int = 6,
                 dropout: float = 0.0,
                 bias: bool = True,
                 ) -> None:
        super().__init__()
        self.need_split = ch_out != ch_in
        self.ch_out = ch_out
        self.encoder_list = torch.nn.ModuleList()
        self.decoder_list = torch.nn.ModuleList()
        # Define the first encoder (ch_in -> channels_extension) outside the loop
        self.encoder_list.append(EncoderStage(ch_in, channels_extension,
                                              k_size=k_conv_ds, dropout=dropout, bias=bias))
        for level in range(1, num_layers+1):
            ch_i = level*channels_extension
            ch_o = (level+1)*channels_extension
            if level < num_layers:
                # Skip the last encoder since the first one was defined outside the loop
                self.encoder_list.append(EncoderStage(ch_i, ch_o, k_size=k_conv_ds,
                                                      dropout=dropout, bias=bias))
            self.decoder_list.append(DecoderStage(ch_o+ch_i, ch_i, k_size=k_conv_us,
                                                  dropout=dropout, bias=bias))
        self.bottleneck = BaseConvolutionBlock(
            num_layers*channels_extension, (num_layers+1)*channels_extension,
            k_size=k_conv_ds, dropout=dropout, bias=bias)
        self.dropout = torch.nn.Dropout1d(p=dropout)
        self.target_modality_conv = torch.nn.Conv1d(
            channels_extension+ch_in, ch_out, 1, bias=bias)  # conv1x1 channel mixer

    def forward(self, x_in: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Forward U-Net pass (shapes shown for ch_in=1, num_layers=4, T=2048):

        ```
        (1  , 2048)----------------->(24 , 2048) > (1  , 2048)
             v                            ^
        (24 , 1024)----------------->(48 , 1024)
             v                            ^
        (48 , 512 )----------------->(72 , 512 )
             v                            ^
        (72 , 256 )----------------->(96 , 256 )
             v                            ^
        (96 , 128 )----BOTTLENECK--->(120, 128 )
        ```
        """
        skipped_list = []
        ds_list = [x_in]
        for level, enc in enumerate(self.encoder_list):
            x_skip, x_ds = enc(ds_list[-1])
            skipped_list.append(x_skip)
            ds_list.append(x_ds.clone())
            # print(x_skip.shape, x_ds.shape)
        x_dec = self.bottleneck(ds_list[-1])
        for level, dec in enumerate(self.decoder_list[::-1]):
            x_dec = dec(x_dec, skipped_list[-1-level])
            # print(x_dec.shape)
        x_dec = torch.cat([x_dec, x_in], dim=1)
        # print(x_dec.shape)
        x_dec = self.dropout(x_dec)
        demuxed = self.target_modality_conv(x_dec)
        # print(demuxed.shape)
        if self.need_split:
            return torch.chunk(demuxed, self.ch_out, dim=1)
        return demuxed, None

# Shape trace (channels_extension=24, T=2048):
#   x_skip        x_ds
#   (24 , 2048), (24 , 1024)
#   (48 , 1024), (48 , 512 )
#   (72 , 512 ), (72 , 256 )
#   (96 , 256 ), (96 , 128 )
#   (120, 128 )   bottleneck
#   (96 , 256 )
#   (72 , 512 )
#   (48 , 1024)
#   (24 , 2048)
#   (25 , 2048)   after concatenating x_in
#   (1  , 2048)   demuxed


if __name__ == "__main__":
    model = WaveUNet(ch_out=1, num_layers=9)
    inp = torch.rand(2, 1, 2048)
    out = model(inp)
    print(model)
    print(model.count_parameters())
    print(out[0].shape)
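    # Hedged aside: count_parameters() is inherited from the SeparationModel
    # base class, which is not shown in this file. If you need the same figure
    # with plain PyTorch, the standard idiom is:
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"trainable parameters: {n_trainable}")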