# Wave_U_Net_audio / model / waveunet.py
# hieupt's picture
# Upload waveunet.py
# cce0a91 verified
import torch
import torch.nn as nn
from model.crop import centre_crop
from model.resample import Resample1d
from model.conv import ConvLayer
class UpsamplingBlock(nn.Module):
    """Decoder stage: upsample, convolve, merge with a shortcut, convolve again.

    Layers are created in this order: the upsampling layer, a stack of
    stride-1 convolutions applied before the merge (``pre_shortcut_convs``),
    and a stack of stride-1 convolutions applied after the merge
    (``post_shortcut_convs``).

    Args:
        n_inputs: Channel count of the incoming (low-resolution) features.
        n_shortcut: Channel count of the shortcut features to merge in.
        n_outputs: Channel count produced by this block.
        kernel_size: Kernel size handed to every ``ConvLayer``.
        stride: Upsampling factor; must be greater than 1.
        depth: Number of layers in each stride-1 stack. The first layer of
            each stack is always created, so values below 1 still give one.
        conv_type: Forwarded unchanged to ``ConvLayer``.
        res: ``"fixed"`` selects a ``Resample1d`` upsampler; any other value
            selects a strided ``ConvLayer`` with ``transpose=True``.
    """

    def __init__(self, n_inputs, n_shortcut, n_outputs, kernel_size, stride, depth, conv_type, res):
        super(UpsamplingBlock, self).__init__()
        assert(stride > 1)

        # CONV 1 for UPSAMPLING
        if res == "fixed":
            self.upconv = Resample1d(n_inputs, 15, stride, transpose=True)
        else:
            self.upconv = ConvLayer(n_inputs, n_inputs, kernel_size, stride, conv_type, transpose=True)

        self.pre_shortcut_convs = nn.ModuleList(
            [ConvLayer(n_inputs, n_outputs, kernel_size, 1, conv_type)] +
            [ConvLayer(n_outputs, n_outputs, kernel_size, 1, conv_type) for _ in range(depth - 1)])

        # CONVS to combine high- with low-level information (from shortcut).
        # Only the FIRST layer is sized for the concatenated input
        # (n_outputs + n_shortcut); the remaining ones map n_outputs -> n_outputs.
        self.post_shortcut_convs = nn.ModuleList(
            [ConvLayer(n_outputs + n_shortcut, n_outputs, kernel_size, 1, conv_type)] +
            [ConvLayer(n_outputs, n_outputs, kernel_size, 1, conv_type) for _ in range(depth - 1)])

    def forward(self, x, shortcut):
        """Upsample ``x``, merge it with ``shortcut`` and return the result.

        Args:
            x: Features to upsample.
            shortcut: Features to merge in; they are centre-cropped to the
                upsampled features before concatenation along ``dim=1``.

        Returns:
            The output of the last post-shortcut convolution.
        """
        # UPSAMPLE HIGH-LEVEL FEATURES
        upsampled = self.upconv(x)
        for conv in self.pre_shortcut_convs:
            upsampled = conv(upsampled)

        # Prepare shortcut connection
        cropped_shortcut = centre_crop(shortcut, upsampled)

        # Combine high- and low-level features exactly once. Concatenating
        # again before every later conv would feed 2 * n_outputs channels into
        # layers declared with n_outputs inputs, which breaks depth > 1
        # (assuming ConvLayer enforces its declared input channel count).
        combined = torch.cat([cropped_shortcut, centre_crop(upsampled, cropped_shortcut)], dim=1)
        for conv in self.post_shortcut_convs:
            combined = conv(combined)
        return combined

    def get_output_size(self, input_size):
        """Return the output length this block yields for ``input_size``.

        The size is propagated through the upsampling layer, then through the
        pre-shortcut and post-shortcut convolutions in forward order, using
        each layer's own ``get_output_size``.
        """
        curr_size = self.upconv.get_output_size(input_size)

        # Upsampling convs
        for conv in self.pre_shortcut_convs:
            curr_size = conv.get_output_size(curr_size)

        # Combine convolutions
        for conv in self.post_shortcut_convs:
            curr_size = conv.get_output_size(curr_size)

        return curr_size
class DownsamplingBlock(nn.Module):
    """Encoder stage: convolve, tap a shortcut, convolve again, then decimate.

    Args:
        n_inputs: Channel count of the incoming features.
        n_shortcut: Channel count of the shortcut features this block emits.
        n_outputs: Channel count of the downsampled output.
        kernel_size: Kernel size handed to every ``ConvLayer``.
        stride: Decimation factor; must be greater than 1.
        depth: Number of layers in each stride-1 stack. The first layer of
            each stack is always created.
        conv_type: Forwarded unchanged to ``ConvLayer``.
        res: ``"fixed"`` selects a ``Resample1d`` decimator; any other value
            selects a strided ``ConvLayer``.
    """

    def __init__(self, n_inputs, n_shortcut, n_outputs, kernel_size, stride, depth, conv_type, res):
        super().__init__()
        assert stride > 1

        self.kernel_size = kernel_size
        self.stride = stride

        extra_layers = range(depth - 1)

        # Stride-1 stack whose result doubles as the shortcut features.
        shortcut_stack = [ConvLayer(n_inputs, n_shortcut, kernel_size, 1, conv_type)]
        shortcut_stack.extend(
            ConvLayer(n_shortcut, n_shortcut, kernel_size, 1, conv_type) for _ in extra_layers)
        self.pre_shortcut_convs = nn.ModuleList(shortcut_stack)

        # Stride-1 stack that changes the channel count ahead of decimation.
        output_stack = [ConvLayer(n_shortcut, n_outputs, kernel_size, 1, conv_type)]
        output_stack.extend(
            ConvLayer(n_outputs, n_outputs, kernel_size, 1, conv_type) for _ in extra_layers)
        self.post_shortcut_convs = nn.ModuleList(output_stack)

        # Decimation layer, built last.
        if res == "fixed":
            # Resampling with fixed-size sinc lowpass filter
            decimator = Resample1d(n_outputs, 15, stride)
        else:
            decimator = ConvLayer(n_outputs, n_outputs, kernel_size, stride, conv_type)
        self.downconv = decimator

    def forward(self, x):
        """Return ``(downsampled, shortcut)`` for the input ``x``.

        ``shortcut`` is the output of the pre-shortcut stack; ``downsampled``
        is that tensor passed through the post-shortcut stack and then the
        decimation layer.
        """
        features = x
        for layer in self.pre_shortcut_convs:
            features = layer(features)
        shortcut = features

        for layer in self.post_shortcut_convs:
            features = layer(features)

        return self.downconv(features), shortcut

    def get_input_size(self, output_size):
        """Return the input length needed to obtain ``output_size`` outputs.

        The size is propagated backwards: decimation layer first, then the
        post-shortcut stack and the pre-shortcut stack, each in reverse order.
        """
        size = self.downconv.get_input_size(output_size)
        backwards = [*reversed(self.post_shortcut_convs), *reversed(self.pre_shortcut_convs)]
        for layer in backwards:
            size = layer.get_input_size(size)
        return size
class Waveunet(nn.Module):
    """Wave-U-Net built from downsampling blocks, bottleneck convs and upsampling blocks.

    Either one network predicts all instruments at once (stored under the key
    ``"ALL"``), or, with ``separate=True``, one network is created per
    instrument.

    Args:
        num_inputs: Channel count fed to the first downsampling block.
        num_channels: Channel count per level; its length sets ``num_levels``.
        num_outputs: Output channels per instrument.
        instruments: Instrument names; used as keys of the returned dict.
        kernel_size: Kernel size for the ``ConvLayer`` instances; must be odd.
        target_output_size: Minimum number of output frames wanted. The
            size actually used (``self.output_size``) may be larger.
        conv_type: Forwarded unchanged to the blocks.
        res: Forwarded unchanged to the blocks.
        separate: If true, build one network per instrument.
        depth: Layers per convolution stack, and number of bottleneck convs.
        strides: Stride used by every down-/upsampling block.
    """

    def __init__(self, num_inputs, num_channels, num_outputs, instruments, kernel_size, target_output_size, conv_type, res, separate=False, depth=1, strides=2):
        super(Waveunet, self).__init__()

        self.num_levels = len(num_channels)
        self.strides = strides
        self.kernel_size = kernel_size
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.depth = depth
        self.instruments = instruments
        self.separate = separate

        # Only odd filter kernels allowed
        assert(kernel_size % 2 == 1)

        self.waveunets = nn.ModuleDict()

        # Create a model for each source if we separate sources separately,
        # otherwise only one (model_list=["ALL"])
        model_list = instruments if separate else ["ALL"]
        for instrument in model_list:
            module = nn.Module()

            module.downsampling_blocks = nn.ModuleList()
            module.upsampling_blocks = nn.ModuleList()

            for i in range(self.num_levels - 1):
                in_ch = num_inputs if i == 0 else num_channels[i]
                module.downsampling_blocks.append(
                    DownsamplingBlock(in_ch, num_channels[i], num_channels[i + 1], kernel_size, strides, depth, conv_type, res))

            # Upsampling blocks walk the channel list from the deepest level back up.
            for i in range(self.num_levels - 1):
                module.upsampling_blocks.append(
                    UpsamplingBlock(num_channels[-1 - i], num_channels[-2 - i], num_channels[-2 - i], kernel_size, strides, depth, conv_type, res))

            module.bottlenecks = nn.ModuleList(
                [ConvLayer(num_channels[-1], num_channels[-1], kernel_size, 1, conv_type) for _ in range(depth)])

            # Output conv: a joint model emits every instrument's channels at once.
            outputs = num_outputs if separate else num_outputs * len(instruments)
            module.output_conv = nn.Conv1d(num_channels[0], outputs, 1)

            self.waveunets[instrument] = module

        self.set_output_size(target_output_size)

    def set_output_size(self, target_output_size):
        """Compute and store input/output sizes for ``target_output_size``.

        Sets ``target_output_size``, ``input_size``, ``output_size`` and the
        ``shapes`` dict (output start/end frame within the input, output
        frame count, input frame count), and prints the chosen sizes.
        """
        self.target_output_size = target_output_size

        self.input_size, self.output_size = self.check_padding(target_output_size)
        print("Using valid convolutions with " + str(self.input_size) + " inputs and " + str(self.output_size) + " outputs")

        # The output must sit centred in the input, so the margin has to be even.
        assert((self.input_size - self.output_size) % 2 == 0)
        margin = (self.input_size - self.output_size) // 2
        self.shapes = {"output_start_frame": margin,
                       "output_end_frame": margin + self.output_size,
                       "output_frames": self.output_size,
                       "input_frames": self.input_size}

    def check_padding(self, target_output_size):
        """Return ``(input_size, output_size)`` for the smallest valid bottleneck.

        Bottleneck sizes 1, 2, 3, ... are tried in turn until
        ``check_padding_for_bottleneck`` accepts one.
        """
        # Ensure number of outputs covers a whole number of cycles so each
        # output in the cycle is weighted equally during training
        bottleneck = 1
        while True:
            out = self.check_padding_for_bottleneck(bottleneck, target_output_size)
            if out is not False:
                return out
            bottleneck += 1

    def check_padding_for_bottleneck(self, bottleneck, target_output_size):
        """Try one bottleneck size.

        Returns:
            ``(input_size, output_size)`` if the sizes can be propagated
            through every layer and ``output_size >= target_output_size``;
            otherwise ``False``.
        """
        # All networks share the same architecture, so the first one is used.
        module = list(self.waveunets.values())[0]

        try:
            # Forward from the bottleneck through the decoder -> output size.
            curr_size = bottleneck
            for block in module.upsampling_blocks:
                curr_size = block.get_output_size(curr_size)
            output_size = curr_size

            # Backward from the bottleneck through the bottleneck convs and
            # the encoder -> required input size.
            curr_size = bottleneck
            for block in reversed(module.bottlenecks):
                curr_size = block.get_input_size(curr_size)
            for block in reversed(module.downsampling_blocks):
                curr_size = block.get_input_size(curr_size)
        except AssertionError:
            # A layer's size helper signalled an invalid size via an assertion.
            return False

        # Explicit comparison rather than `assert`: assert statements are
        # stripped under `python -O`, which would accept any bottleneck size.
        if output_size < target_output_size:
            return False
        return curr_size, output_size

    def forward_module(self, x, module):
        '''
        A forward pass through a single Wave-U-Net (multiple Wave-U-Nets might be used, one for each source)
        :param x: Input mix
        :param module: Network module to be used for prediction
        :return: Source estimates
        '''
        shortcuts = []
        out = x

        # DOWNSAMPLING BLOCKS
        for block in module.downsampling_blocks:
            out, short = block(out)
            shortcuts.append(short)

        # BOTTLENECK CONVOLUTION
        for conv in module.bottlenecks:
            out = conv(out)

        # UPSAMPLING BLOCKS: consume the shortcuts deepest-first
        for idx, block in enumerate(module.upsampling_blocks):
            out = block(out, shortcuts[-1 - idx])

        # OUTPUT CONV
        out = module.output_conv(out)
        if not self.training:  # At test time clip predictions to valid amplitude range
            out = out.clamp(min=-1.0, max=1.0)
        return out

    def forward(self, x, inst=None):
        """Run the network and return a dict mapping instrument name to output.

        Args:
            x: Input whose last dimension must equal ``self.input_size``.
            inst: Instrument to predict. Required when ``separate`` is true;
                ignored otherwise.

        Returns:
            With ``separate``: ``{inst: output}`` for the requested
            instrument. Otherwise one entry per instrument, each a
            ``num_outputs``-wide channel slice of the joint output.

        Raises:
            KeyError: If ``separate`` is true and ``inst`` names no network.
        """
        curr_input_size = x.shape[-1]
        # User promises to feed the proper input himself, to get the
        # pre-calculated (NOT the originally desired) output size
        assert curr_input_size == self.input_size, \
            "expected input length %s but got %s" % (self.input_size, curr_input_size)

        if self.separate:
            if inst not in self.waveunets:
                raise KeyError("no network for instrument %r; pass one of %s as inst"
                               % (inst, list(self.waveunets.keys())))
            return {inst: self.forward_module(x, self.waveunets[inst])}

        assert(len(self.waveunets) == 1)
        out = self.forward_module(x, self.waveunets["ALL"])

        width = self.num_outputs
        return {name: out[:, idx * width:(idx + 1) * width]
                for idx, name in enumerate(self.instruments)}