import torch
from gyraudio.audio_separation.architecture.model import SeparationModel
from gyraudio.audio_separation.architecture.building_block import ResConvolution
from typing import Optional
# import logging


class EncoderSingleStage(torch.nn.Module):
    """Single encoder stage:
    - extend channels (ch -> ch_out)
    - residual convolution block
    - downsample by 2 (max pooling)
    """
def __init__(self, ch: int, ch_out: int, hdim: Optional[int] = None, k_size=5):
        # ch_out ~ ch * extension_factor
super().__init__()
hdim = hdim or ch
self.extension_conv = torch.nn.Conv1d(ch, ch_out, k_size, padding=k_size//2)
self.res_conv = ResConvolution(ch_out, hdim=hdim, k_size=k_size)
        # Warning: max pooling with stride 2 is sensitive to a one-sample offset (jitter)
        # of the input, since a shift changes which samples fall into each pooling window.
self.max_pool = torch.nn.MaxPool1d(kernel_size=2)

    def forward(self, x):
        """Return full-resolution features (kept for the skip connection)
        and the same features downsampled by a factor of 2."""
x = self.extension_conv(x)
x = self.res_conv(x)
x_ds = self.max_pool(x)
return x, x_ds
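
# Illustrative shape check for a single encoder stage (a sketch, not part of the model;
# it assumes ResConvolution preserves the (channel, time) shape, as it is used throughout this file):
#   stage = EncoderSingleStage(ch=16, ch_out=24)
#   x_skip, x_ds = stage(torch.rand(2, 16, 2048))
#   x_skip.shape == (2, 24, 2048)  # kept for the skip connection
#   x_ds.shape == (2, 24, 1024)    # fed to the next stage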


class DecoderSingleStage(torch.nn.Module):
    """Single decoder stage:
    - upsample by 2
    - merge with the skip connection and mix channels (2*ch -> ch_out)
    - residual convolution block
    """
    def __init__(self, ch: int, ch_out: int, hdim: Optional[int] = None, k_size=5):
        """Decoder stage

        Args:
            ch (int): channel size (the downsampled input and the skip connection have the same channel size)
            ch_out (int): number of output channels (shall match the number of input channels of the next stage)
            hdim (Optional[int], optional): hidden dimension used in the residual block. Defaults to None (= ch).
            k_size (int, optional): convolution kernel size. Defaults to 5.

        Notes:
            ch_out ~ ch/extension_factor
            self.scale_mixers_conv
            - defines how the lower decoded scale (x_ds) is merged with the current encoded scale (x_skip)
            - maps the 2*ch concatenated channels to ch_out; could be a pointwise convolution (conv1x1)
        """
super().__init__()
hdim = hdim or ch
self.scale_mixers_conv = torch.nn.Conv1d(2*ch, ch_out, k_size, padding=k_size//2)
self.res_conv = ResConvolution(ch_out, hdim=hdim, k_size=k_size)
        # warning: linear interpolation shall be "conjugated" with the encoder downsampling,
        # and special care shall be taken regarding offsets
# https://arxiv.org/abs/1806.03185
self.upsample = torch.nn.Upsample(scale_factor=2, mode="linear", align_corners=True)
self.non_linearity = torch.nn.ReLU()

    def forward(self, x_ds: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor:
        """Upsample the decoded features, merge them with the skip connection,
        mix channels, then refine with the residual block."""
x_us = self.upsample(x_ds) # [N, ch, T/2] -> [N, ch, T]
x = torch.cat([x_us, x_skip], dim=1) # [N, 2.ch, T]
x = self.scale_mixers_conv(x) # [N, ch_out, T]
x = self.non_linearity(x)
x = self.res_conv(x) # [N, ch_out, T]
return x
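
# Illustrative shape check for a single decoder stage (a sketch; channel sizes follow the
# 24 -> 16 pairing that ResUNet uses at its outermost level):
#   stage = DecoderSingleStage(ch=24, ch_out=16)
#   y = stage(x_ds=torch.rand(2, 24, 1024), x_skip=torch.rand(2, 24, 2048))
#   y.shape == (2, 16, 2048)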


class ResUNet(SeparationModel):
    """Residual U-Net for audio separation:
    4 encoder stages (each decimating the temporal axis by 2), a residual bottleneck,
    and 4 decoder stages merging the skip connections back in.
    """
    def __init__(self,
ch_in: int = 1,
ch_out: int = 2,
channels_extension: float = 1.5,
h_dim=16,
k_size=5,
) -> None:
super().__init__()
self.need_split = ch_out != ch_in
self.ch_out = ch_out
self.source_modality_conv = torch.nn.Conv1d(ch_in, h_dim, k_size, padding=k_size//2)
self.encoder_list = torch.nn.ModuleList()
self.decoder_list = torch.nn.ModuleList()
self.non_linearity = torch.nn.ReLU()
h_dim_current = h_dim
for _level in range(4):
h_dim_ds = int(h_dim_current*channels_extension)
self.encoder_list.append(EncoderSingleStage(h_dim_current, h_dim_ds, k_size=k_size))
self.decoder_list.append(DecoderSingleStage(h_dim_ds, h_dim_current, k_size=k_size))
h_dim_current = h_dim_ds
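        # With the defaults (h_dim=16, channels_extension=1.5), the encoder channel sizes
        # grow as 16 -> 24 -> 36 -> 54 -> 81 (int truncation at each stage) while the
        # temporal length is halved at each of the 4 stages.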
self.bottleneck = ResConvolution(h_dim_current, k_size=k_size)
self.target_modality_conv = torch.nn.Conv1d(h_dim, ch_out, 1) # conv1x1 channel mixer

    def forward(self, x_in):
        """Separate the input mix; returns `ch_out` single-channel estimates when
        ch_out != ch_in, otherwise (output, None)."""
        # x_in -> (1, 2048)
x0 = self.source_modality_conv(x_in)
x0 = self.non_linearity(x0)
# x0 -> (16, 2048)
x1_skip, x1_ds = self.encoder_list[0](x0)
# x1_skip -> (24, 2048)
# x1_ds -> (24, 1024)
# print(x1_skip.shape, x1_ds.shape)
x2_skip, x2_ds = self.encoder_list[1](x1_ds)
# x2_skip -> (36, 1024)
# x2_ds -> (36, 512)
# print(x2_skip.shape, x2_ds.shape)
x3_skip, x3_ds = self.encoder_list[2](x2_ds)
# x3_skip -> (54, 512)
# x3_ds -> (54, 256)
# print(x3_skip.shape, x3_ds.shape)
x4_skip, x4_ds = self.encoder_list[3](x3_ds)
# x4_skip -> (81, 256)
# x4_ds -> (81, 128)
# print(x4_skip.shape, x4_ds.shape)
x4_dec = self.bottleneck(x4_ds)
x3_dec = self.decoder_list[3](x4_dec, x4_skip)
x2_dec = self.decoder_list[2](x3_dec, x3_skip)
x1_dec = self.decoder_list[1](x2_dec, x2_skip)
x0_dec = self.decoder_list[0](x1_dec, x1_skip)
demuxed = self.target_modality_conv(x0_dec)
# no relu
if self.need_split:
return torch.chunk(demuxed, self.ch_out, dim=1)
return demuxed, None
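
# Minimal usage sketch (assumptions: batches shaped [N, ch_in, T] with T divisible by
# 2**4 = 16 so the four pooling/upsampling stages line up):
#   model = ResUNet(ch_in=1, ch_out=2)
#   out_a, out_b = model(torch.rand(4, 1, 2048))  # need_split: two [4, 1, 2048] estimates
#   model = ResUNet(ch_in=1, ch_out=1)
#   estimate, _ = model(torch.rand(4, 1, 2048))   # ch_out == ch_in: returns (tensor, None)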


if __name__ == "__main__":
model = ResUNet()
inp = torch.rand(2, 1, 2048)
out = model(inp)
print(model)
print(model.count_parameters())
print(out[0].shape)
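    # With the defaults above (ch_in=1, ch_out=2), out is a pair of [2, 1, 2048] tensors,
    # so the last print shows torch.Size([2, 1, 2048]).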