import torch
import torch.nn as nn

from model.crop import centre_crop
from model.resample import Resample1d
from model.conv import ConvLayer

class UpsamplingBlock(nn.Module):
    def __init__(self, n_inputs, n_shortcut, n_outputs, kernel_size, stride, depth, conv_type, res):
        super(UpsamplingBlock, self).__init__()
        assert(stride > 1)

        # CONV 1 for UPSAMPLING
        if res == "fixed":
            self.upconv = Resample1d(n_inputs, 15, stride, transpose=True)
        else:
            self.upconv = ConvLayer(n_inputs, n_inputs, kernel_size, stride, conv_type, transpose=True)

        self.pre_shortcut_convs = nn.ModuleList([ConvLayer(n_inputs, n_outputs, kernel_size, 1, conv_type)] +
                                                [ConvLayer(n_outputs, n_outputs, kernel_size, 1, conv_type) for _ in range(depth - 1)])

        # CONVS to combine high- with low-level information (from shortcut)
        self.post_shortcut_convs = nn.ModuleList([ConvLayer(n_outputs + n_shortcut, n_outputs, kernel_size, 1, conv_type)] +
                                                 [ConvLayer(n_outputs, n_outputs, kernel_size, 1, conv_type) for _ in range(depth - 1)])

    def forward(self, x, shortcut):
        # UPSAMPLE HIGH-LEVEL FEATURES
        upsampled = self.upconv(x)

        for conv in self.pre_shortcut_convs:
            upsampled = conv(upsampled)

        # Prepare shortcut connection
        combined = centre_crop(shortcut, upsampled)

        # Combine high- and low-level features
        for conv in self.post_shortcut_convs:
            combined = conv(torch.cat([combined, centre_crop(upsampled, combined)], dim=1))
        return combined

    def get_output_size(self, input_size):
        curr_size = self.upconv.get_output_size(input_size)

        # Upsampling convs
        for conv in self.pre_shortcut_convs:
            curr_size = conv.get_output_size(curr_size)

        # Combine convolutions
        for conv in self.post_shortcut_convs:
            curr_size = conv.get_output_size(curr_size)
        return curr_size
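
# Minimal sketch (not part of the original module) of how centre_crop aligns
# the shortcut with the upsampled features before concatenation. It assumes
# centre_crop (model/crop.py) symmetrically trims the last axis of its first
# argument to the length of its second. Hypothetical helper, never called by
# the model itself.
def _demo_centre_crop_alignment():
    shortcut = torch.randn(1, 4, 101)  # low-level features, longer in time
    upsampled = torch.randn(1, 4, 97)  # upsampled high-level features
    cropped = centre_crop(shortcut, upsampled)
    assert cropped.shape[-1] == upsampled.shape[-1]  # lengths now match for torch.cat
    return torch.cat([cropped, upsampled], dim=1)  # shape (1, 8, 97), as in forward()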

class DownsamplingBlock(nn.Module):
    def __init__(self, n_inputs, n_shortcut, n_outputs, kernel_size, stride, depth, conv_type, res):
        super(DownsamplingBlock, self).__init__()
        assert(stride > 1)

        self.kernel_size = kernel_size
        self.stride = stride

        # CONV 1
        self.pre_shortcut_convs = nn.ModuleList([ConvLayer(n_inputs, n_shortcut, kernel_size, 1, conv_type)] +
                                                [ConvLayer(n_shortcut, n_shortcut, kernel_size, 1, conv_type) for _ in range(depth - 1)])

        self.post_shortcut_convs = nn.ModuleList([ConvLayer(n_shortcut, n_outputs, kernel_size, 1, conv_type)] +
                                                 [ConvLayer(n_outputs, n_outputs, kernel_size, 1, conv_type) for _ in range(depth - 1)])

        # CONV 2 with decimation
        if res == "fixed":
            self.downconv = Resample1d(n_outputs, 15, stride)  # Resampling with fixed-size sinc lowpass filter
        else:
            self.downconv = ConvLayer(n_outputs, n_outputs, kernel_size, stride, conv_type)

    def forward(self, x):
        # PREPARING SHORTCUT FEATURES
        shortcut = x
        for conv in self.pre_shortcut_convs:
            shortcut = conv(shortcut)

        # PREPARING FOR DOWNSAMPLING
        out = shortcut
        for conv in self.post_shortcut_convs:
            out = conv(out)

        # DOWNSAMPLING
        out = self.downconv(out)

        return out, shortcut

    def get_input_size(self, output_size):
        curr_size = self.downconv.get_input_size(output_size)

        for conv in reversed(self.post_shortcut_convs):
            curr_size = conv.get_input_size(curr_size)

        for conv in reversed(self.pre_shortcut_convs):
            curr_size = conv.get_input_size(curr_size)
        return curr_size
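
# Hedged sanity-check sketch (not in the original file): get_input_size is
# meant to invert the block's forward-pass shape arithmetic, so an input of the
# computed size should yield exactly the requested number of output frames.
# conv_type="normal" and res="learned" are assumptions here, and the exact
# equality relies on ConvLayer.get_input_size (model/conv.py) inverting valid
# convolutions exactly. Hypothetical helper, never called by the model itself.
def _demo_downsampling_shapes():
    block = DownsamplingBlock(n_inputs=8, n_shortcut=8, n_outputs=16,
                              kernel_size=5, stride=2, depth=1,
                              conv_type="normal", res="learned")
    out_size = 100
    in_size = block.get_input_size(out_size)
    x = torch.randn(1, 8, in_size)
    out, shortcut = block(x)
    assert out.shape[-1] == out_size  # downsampled path matches the request
    assert shortcut.shape[-1] > out.shape[-1]  # shortcut keeps the higher time resolution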

class Waveunet(nn.Module):
    def __init__(self, num_inputs, num_channels, num_outputs, instruments, kernel_size, target_output_size, conv_type, res, separate=False, depth=1, strides=2):
        super(Waveunet, self).__init__()

        self.num_levels = len(num_channels)
        self.strides = strides
        self.kernel_size = kernel_size
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.depth = depth
        self.instruments = instruments
        self.separate = separate

        # Only odd filter kernels allowed
        assert(kernel_size % 2 == 1)

        self.waveunets = nn.ModuleDict()
        model_list = instruments if separate else ["ALL"]

        # Create one model per source if separate=True, otherwise a single model estimating all sources (model_list=["ALL"])
        for instrument in model_list:
            module = nn.Module()

            module.downsampling_blocks = nn.ModuleList()
            module.upsampling_blocks = nn.ModuleList()

            for i in range(self.num_levels - 1):
                in_ch = num_inputs if i == 0 else num_channels[i]

                module.downsampling_blocks.append(
                    DownsamplingBlock(in_ch, num_channels[i], num_channels[i+1], kernel_size, strides, depth, conv_type, res))

            for i in range(0, self.num_levels - 1):
                module.upsampling_blocks.append(
                    UpsamplingBlock(num_channels[-1-i], num_channels[-2-i], num_channels[-2-i], kernel_size, strides, depth, conv_type, res))

            module.bottlenecks = nn.ModuleList(
                [ConvLayer(num_channels[-1], num_channels[-1], kernel_size, 1, conv_type) for _ in range(depth)])

            # Output conv
            outputs = num_outputs if separate else num_outputs * len(instruments)
            module.output_conv = nn.Conv1d(num_channels[0], outputs, 1)

            self.waveunets[instrument] = module

        self.set_output_size(target_output_size)

    def set_output_size(self, target_output_size):
        self.target_output_size = target_output_size

        self.input_size, self.output_size = self.check_padding(target_output_size)
        print("Using valid convolutions with " + str(self.input_size) + " inputs and " + str(self.output_size) + " outputs")

        assert((self.input_size - self.output_size) % 2 == 0)
        self.shapes = {"output_start_frame": (self.input_size - self.output_size) // 2,
                       "output_end_frame": (self.input_size - self.output_size) // 2 + self.output_size,
                       "output_frames": self.output_size,
                       "input_frames": self.input_size}

    def check_padding(self, target_output_size):
        # Ensure number of outputs covers a whole number of cycles so each output in the cycle is weighted equally during training
        bottleneck = 1

        while True:
            out = self.check_padding_for_bottleneck(bottleneck, target_output_size)
            if out is not False:
                return out
            bottleneck += 1

    def check_padding_for_bottleneck(self, bottleneck, target_output_size):
        module = next(iter(self.waveunets.values()))
        try:
            curr_size = bottleneck
            for block in module.upsampling_blocks:
                curr_size = block.get_output_size(curr_size)
            output_size = curr_size

            # Bottleneck-Conv
            curr_size = bottleneck
            for block in reversed(module.bottlenecks):
                curr_size = block.get_input_size(curr_size)

            for block in reversed(module.downsampling_blocks):
                curr_size = block.get_input_size(curr_size)

            assert(output_size >= target_output_size)
            return curr_size, output_size
        except AssertionError:
            return False

    def forward_module(self, x, module):
        '''
        A forward pass through a single Wave-U-Net (multiple Wave-U-Nets might be used, one for each source)
        :param x: Input mix
        :param module: Network module to be used for prediction
        :return: Source estimates
        '''
        shortcuts = []
        out = x

        # DOWNSAMPLING BLOCKS
        for block in module.downsampling_blocks:
            out, short = block(out)
            shortcuts.append(short)

        # BOTTLENECK CONVOLUTION
        for conv in module.bottlenecks:
            out = conv(out)

        # UPSAMPLING BLOCKS
        for idx, block in enumerate(module.upsampling_blocks):
            out = block(out, shortcuts[-1 - idx])

        # OUTPUT CONV
        out = module.output_conv(out)
        if not self.training:  # At test time clip predictions to valid amplitude range
            out = out.clamp(min=-1.0, max=1.0)
        return out

    def forward(self, x, inst=None):
        curr_input_size = x.shape[-1]
        assert(curr_input_size == self.input_size)  # The caller must provide input of exactly this size to obtain the pre-calculated (NOT the originally requested) output size

        if self.separate:
            return {inst: self.forward_module(x, self.waveunets[inst])}
        else:
            assert(len(self.waveunets) == 1)
            out = self.forward_module(x, self.waveunets["ALL"])

            out_dict = {}
            for idx, inst in enumerate(self.instruments):
                out_dict[inst] = out[:, idx * self.num_outputs:(idx + 1) * self.num_outputs]
            return out_dict
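
if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. All
    # hyperparameter values below are assumptions chosen for a quick run,
    # not the settings from the Wave-U-Net paper or training scripts.
    model = Waveunet(num_inputs=2,
                     num_channels=[32, 64, 128],
                     num_outputs=2,
                     instruments=["bass", "drums", "other", "vocals"],
                     kernel_size=5,
                     target_output_size=16384,
                     conv_type="normal",  # assumed: plain convs without normalisation
                     res="learned",       # assumed: learned (conv-based) resampling
                     separate=False,
                     depth=1,
                     strides=2)
    mix = torch.randn(1, 2, model.shapes["input_frames"])  # (batch, channels, samples)
    with torch.no_grad():
        estimates = model(mix)
    for name, est in estimates.items():
        print(name, est.shape)  # each estimate covers model.shapes["output_frames"] samples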