import torch
from gyraudio.audio_separation.architecture.model import SeparationModel
from typing import Optional, Tuple
def get_non_linearity(activation: str):
if activation == "LeakyReLU":
non_linearity = torch.nn.LeakyReLU()
else:
non_linearity = torch.nn.ReLU()
return non_linearity
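# Note: only "LeakyReLU" is matched explicitly; any other activation string
# falls back to torch.nn.ReLU().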
class BaseConvolutionBlock(torch.nn.Module):
def __init__(self, ch_in, ch_out: int, k_size: int, activation="LeakyReLU", dropout: float = 0, bias: bool = True) -> None:
super().__init__()
self.conv = torch.nn.Conv1d(ch_in, ch_out, k_size, padding=k_size//2, bias=bias)
self.non_linearity = get_non_linearity(activation)
self.dropout = torch.nn.Dropout1d(p=dropout)
def forward(self, x_in: torch.Tensor) -> torch.Tensor:
        x = self.conv(x_in)  # [N, ch_in, T] -> [N, ch_out, T]
x = self.non_linearity(x)
x = self.dropout(x)
return x
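# Shape-check sketch (illustrative values, kept as a comment so importing the module
# has no side effects):
#   block = BaseConvolutionBlock(ch_in=1, ch_out=24, k_size=15)
#   y = block(torch.rand(4, 1, 1024))  # padding=k_size//2 keeps the length: [4, 24, 1024]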
class EncoderStage(torch.nn.Module):
"""Conv (and extend channels), downsample 2 by skipping samples
"""
def __init__(self, ch_in: int, ch_out: int, k_size: int = 15, dropout: float = 0, bias: bool = True) -> None:
super().__init__()
self.conv_block = BaseConvolutionBlock(ch_in, ch_out, k_size=k_size, dropout=dropout, bias=bias)
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        x = self.conv_block(x)  # [N, ch_in, T] -> [N, ch_out, T]
        x_ds = x[..., ::2]      # keep every other sample: [N, ch_out, T] -> [N, ch_out, T//2]
        # In the encoder path, ch_out = ch_in + channels_extension
        return x, x_ds          # full-resolution skip tensor, downsampled tensor
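# Sketch of one encoder stage (illustrative values):
#   enc = EncoderStage(ch_in=24, ch_out=48)
#   x_skip, x_ds = enc(torch.rand(4, 24, 1024))
#   # x_skip: [4, 48, 1024] (skip connection), x_ds: [4, 48, 512] (fed to the next stage)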
class DecoderStage(torch.nn.Module):
"""Upsample by 2, Concatenate with skip connection, Conv (and shrink channels)
"""
def __init__(self, ch_in: int, ch_out: int, k_size: int = 5, dropout: float = 0., bias: bool = True) -> None:
"""Decoder stage
"""
super().__init__()
self.conv_block = BaseConvolutionBlock(ch_in, ch_out, k_size=k_size, dropout=dropout, bias=bias)
self.upsample = torch.nn.Upsample(scale_factor=2, mode="linear", align_corners=True)
def forward(self, x_ds: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor:
""""""
x_us = self.upsample(x_ds) # [N, ch, T/2] -> [N, ch, T]
x = torch.cat([x_us, x_skip], dim=1) # [N, 2.ch, T]
x = self.conv_block(x) # [N, ch_out, T]
return x
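# Sketch of one decoder stage (illustrative values, mirroring the encoder sketch above):
#   dec = DecoderStage(ch_in=48 + 48, ch_out=24)
#   y = dec(torch.rand(4, 48, 512), torch.rand(4, 48, 1024))  # -> [4, 24, 1024]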
class WaveUNet(SeparationModel):
"""UNET in temporal domain (waveform)
= Multiscale convolutional neural network for audio separation
https://arxiv.org/abs/1806.03185
"""
def __init__(self,
ch_in: int = 1,
ch_out: int = 2,
channels_extension: int = 24,
k_conv_ds: int = 15,
k_conv_us: int = 5,
num_layers: int = 6,
dropout: float = 0.0,
bias: bool = True,
) -> None:
super().__init__()
self.need_split = ch_out != ch_in
self.ch_out = ch_out
self.encoder_list = torch.nn.ModuleList()
self.decoder_list = torch.nn.ModuleList()
# Defining first encoder
self.encoder_list.append(EncoderStage(ch_in, channels_extension, k_size=k_conv_ds, dropout=dropout, bias=bias))
for level in range(1, num_layers+1):
ch_i = level*channels_extension
ch_o = (level+1)*channels_extension
if level < num_layers:
                # The first encoder was created outside the loop and the deepest level is
                # handled by the bottleneck, so only num_layers-1 encoders are added here
self.encoder_list.append(EncoderStage(ch_i, ch_o, k_size=k_conv_ds, dropout=dropout, bias=bias))
self.decoder_list.append(DecoderStage(ch_o+ch_i, ch_i, k_size=k_conv_us, dropout=dropout, bias=bias))
self.bottleneck = BaseConvolutionBlock(
num_layers*channels_extension,
(num_layers+1)*channels_extension,
k_size=k_conv_ds,
dropout=dropout,
bias=bias)
self.dropout = torch.nn.Dropout1d(p=dropout)
self.target_modality_conv = torch.nn.Conv1d(
channels_extension+ch_in, ch_out, 1, bias=bias) # conv1x1 channel mixer
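    # Channel plan with the defaults (ch_in=1, channels_extension=24, num_layers=6),
    # as built above:
    #   encoders  : 1->24, 24->48, 48->72, 72->96, 96->120, 120->144
    #   bottleneck: 144->168
    #   decoders  : (168+144)->144, (144+120)->120, ..., (48+24)->24
    #   final 1x1 : 24+ch_in -> ch_out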
def forward(self, x_in: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Forward UNET pass
```
(1 , 2048)----------------->(24 , 2048) > (1 , 2048)
v ^
(24 , 1024)----------------->(48 , 1024)
v ^
(48 , 512 )----------------->(72 , 512 )
v ^
(72 , 256 )----------------->(96 , 256 )
v ^
(96 , 128 )----BOTTLENECK--->(120, 128 )
```
"""
skipped_list = []
ds_list = [x_in]
for level, enc in enumerate(self.encoder_list):
x_skip, x_ds = enc(ds_list[-1])
skipped_list.append(x_skip)
ds_list.append(x_ds.clone())
# print(x_skip.shape, x_ds.shape)
x_dec = self.bottleneck(ds_list[-1])
for level, dec in enumerate(self.decoder_list[::-1]):
x_dec = dec(x_dec, skipped_list[-1-level])
# print(x_dec.shape)
x_dec = torch.cat([x_dec, x_in], dim=1)
# print(x_dec.shape)
x_dec = self.dropout(x_dec)
demuxed = self.target_modality_conv(x_dec)
# print(demuxed.shape)
if self.need_split:
return torch.chunk(demuxed, self.ch_out, dim=1)
return demuxed, None
# x_skip, x_ds
# (24, 2048), (24, 1024)
# (48, 1024), (48, 512 )
# (72, 512 ), (72, 256 )
# (96, 256 ), (96, 128 )
# (120, 128 )
# (96 , 256 )
# (72 , 512 )
# (48 , 1024)
# (24 , 2048)
# (25 , 2048) demuxed - after concat
# (1 , 2048)
if __name__ == "__main__":
model = WaveUNet(ch_out=1, num_layers=9)
inp = torch.rand(2, 1, 2048)
out = model(inp)
print(model)
print(model.count_parameters())
print(out[0].shape)
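    # Sketch (assumed usage): with the default ch_out=2, need_split is True and the
    # forward pass returns one single-channel tensor per separated source via torch.chunk.
    model_two_sources = WaveUNet()  # ch_in=1, ch_out=2, num_layers=6
    estimate_a, estimate_b = model_two_sources(inp)
    print(estimate_a.shape, estimate_b.shape)  # expected: torch.Size([2, 1, 2048]) each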