import numpy as np
from typing import Dict, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchlibrosa.stft import STFT, ISTFT, magphase

from models.base import Base, init_layer, init_bn


class FiLM(nn.Module):
    def __init__(self, film_meta, condition_size):
        r"""Create FiLM (feature-wise linear modulation) layers from the
        nested ``film_meta`` description. Each integer leaf becomes an
        nn.Linear mapping a condition vector to that many features."""
        super(FiLM, self).__init__()

        self.condition_size = condition_size

        # Note: this stores a plain dict attribute named ``modules``, which
        # shadows nn.Module.modules(); the layers themselves are registered
        # via add_module() in add_film_layer_to_module().
        self.modules, _ = self.create_film_modules(
            film_meta=film_meta,
            ancestor_names=[],
        )

    def create_film_modules(self, film_meta, ancestor_names):

        modules = {}

        # Pre-order traversal of modules.
        for module_name, value in film_meta.items():

            if isinstance(value, int):

                ancestor_names.append(module_name)
                unique_module_name = '->'.join(ancestor_names)

                modules[module_name] = self.add_film_layer_to_module(
                    num_features=value,
                    unique_module_name=unique_module_name,
                )

            elif isinstance(value, dict):

                ancestor_names.append(module_name)

                modules[module_name], _ = self.create_film_modules(
                    film_meta=value,
                    ancestor_names=ancestor_names,
                )

            ancestor_names.pop()

        return modules, ancestor_names

    def add_film_layer_to_module(self, num_features, unique_module_name):

        layer = nn.Linear(self.condition_size, num_features)
        init_layer(layer)
        self.add_module(name=unique_module_name, module=layer)

        return layer

    def forward(self, conditions):

        film_dict = self.calculate_film_data(
            conditions=conditions,
            modules=self.modules,
        )

        return film_dict

    def calculate_film_data(self, conditions, modules):

        film_data = {}

        # Pre-order traversal of modules.
        for module_name, module in modules.items():

            if isinstance(module, nn.Module):
                film_data[module_name] = module(conditions)[:, :, None, None]

            elif isinstance(module, dict):
                film_data[module_name] = self.calculate_film_data(conditions, module)

        return film_data
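
# A minimal sketch (not part of the original model) showing how FiLM maps a
# nested film_meta description to linear layers and how forward() returns a
# matching nested dict of (batch, features, 1, 1) modulation tensors. The toy
# sizes below are assumptions for illustration only.
def _demo_film_mapping():
    film_meta = {'block1': {'beta1': 4, 'beta2': 4}}
    film = FiLM(film_meta=film_meta, condition_size=8)

    conditions = torch.randn(2, 8)  # (batch_size, condition_size)
    film_dict = film(conditions)

    # Each leaf is broadcast-addable to a (batch, channels, time, freq) map.
    assert film_dict['block1']['beta1'].shape == (2, 4, 1, 1)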


class ConvBlockRes(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple,
        momentum: float,
        has_film,
    ):
        r"""Residual block."""
        super(ConvBlockRes, self).__init__()

        padding = [kernel_size[0] // 2, kernel_size[1] // 2]

        self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)

        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=(1, 1),
            dilation=(1, 1),
            padding=padding,
            bias=False,
        )

        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=(1, 1),
            dilation=(1, 1),
            padding=padding,
            bias=False,
        )

        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(1, 1),
                stride=(1, 1),
                padding=(0, 0),
            )
            self.is_shortcut = True
        else:
            self.is_shortcut = False

        self.has_film = has_film

        self.init_weights()

    def init_weights(self) -> None:
        r"""Initialize weights."""
        init_bn(self.bn1)
        init_bn(self.bn2)
        init_layer(self.conv1)
        init_layer(self.conv2)

        if self.is_shortcut:
            init_layer(self.shortcut)

    def forward(self, input_tensor: torch.Tensor, film_dict: Dict) -> torch.Tensor:
        r"""Forward data into the module.

        Args:
            input_tensor: (batch_size, input_feature_maps, time_steps, freq_bins)

        Returns:
            output_tensor: (batch_size, output_feature_maps, time_steps, freq_bins)
        """
        b1 = film_dict['beta1']
        b2 = film_dict['beta2']

        x = self.conv1(F.leaky_relu_(self.bn1(input_tensor) + b1, negative_slope=0.01))
        x = self.conv2(F.leaky_relu_(self.bn2(x) + b2, negative_slope=0.01))

        if self.is_shortcut:
            return self.shortcut(input_tensor) + x
        else:
            return input_tensor + x
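
# A minimal sketch (assumed toy sizes, not from the original repo): with zero
# FiLM betas, ConvBlockRes reduces to a plain pre-activation residual block;
# the 1x1 shortcut handles the channel change.
def _demo_conv_block_res():
    block = ConvBlockRes(
        in_channels=4, out_channels=8, kernel_size=(3, 3), momentum=0.01, has_film=True,
    )
    x = torch.randn(2, 4, 16, 16)  # (batch, channels, time_steps, freq_bins)
    film_dict = {
        'beta1': torch.zeros(2, 4, 1, 1),  # added after bn1 (in_channels features)
        'beta2': torch.zeros(2, 8, 1, 1),  # added after bn2 (out_channels features)
    }
    y = block(x, film_dict)
    assert y.shape == (2, 8, 16, 16)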


class EncoderBlockRes1B(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple,
        downsample: Tuple,
        momentum: float,
        has_film,
    ):
        r"""Encoder block: one residual conv block (2 convolutional layers)
        followed by average pooling."""
        super(EncoderBlockRes1B, self).__init__()

        self.conv_block1 = ConvBlockRes(
            in_channels, out_channels, kernel_size, momentum, has_film,
        )
        self.downsample = downsample

    def forward(self, input_tensor: torch.Tensor, film_dict: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Forward data into the module.

        Args:
            input_tensor: (batch_size, input_feature_maps, time_steps, freq_bins)

        Returns:
            encoder_pool: (batch_size, output_feature_maps, downsampled_time_steps, downsampled_freq_bins)
            encoder: (batch_size, output_feature_maps, time_steps, freq_bins)
        """
        encoder = self.conv_block1(input_tensor, film_dict['conv_block1'])
        encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)

        return encoder_pool, encoder


class DecoderBlockRes1B(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple,
        upsample: Tuple,
        momentum: float,
        has_film,
    ):
        r"""Decoder block: one transposed convolutional layer followed by a
        residual conv block (2 convolutional layers)."""
        super(DecoderBlockRes1B, self).__init__()

        self.kernel_size = kernel_size
        self.stride = upsample

        self.conv1 = torch.nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=self.stride,
            stride=self.stride,
            padding=(0, 0),
            bias=False,
            dilation=(1, 1),
        )

        self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.conv_block2 = ConvBlockRes(
            out_channels * 2, out_channels, kernel_size, momentum, has_film,
        )
        # bn2 is not used in forward(), but get_film_meta() below reads its
        # size, so it is kept.
        self.bn2 = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.has_film = has_film

        self.init_weights()

    def init_weights(self) -> None:
        r"""Initialize weights."""
        init_bn(self.bn1)
        init_layer(self.conv1)

    def forward(
        self, input_tensor: torch.Tensor, concat_tensor: torch.Tensor, film_dict: Dict,
    ) -> torch.Tensor:
        r"""Forward data into the module.

        Args:
            input_tensor: (batch_size, input_feature_maps, downsampled_time_steps, downsampled_freq_bins)
            concat_tensor: (batch_size, input_feature_maps, time_steps, freq_bins)

        Returns:
            output_tensor: (batch_size, output_feature_maps, time_steps, freq_bins)
        """
        b1 = film_dict['beta1']

        x = self.conv1(F.leaky_relu_(self.bn1(input_tensor) + b1))
        # (batch_size, output_feature_maps, time_steps, freq_bins)

        x = torch.cat((x, concat_tensor), dim=1)
        # (batch_size, output_feature_maps * 2, time_steps, freq_bins)

        x = self.conv_block2(x, film_dict['conv_block2'])
        # (batch_size, output_feature_maps, time_steps, freq_bins)

        return x
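
# A minimal sketch (assumed sizes): the decoder block upsamples with the
# transposed conv, concatenates the encoder skip tensor along channels, then
# fuses both with the residual conv block. Note that conv_block2 therefore
# expects betas sized for out_channels * 2 and out_channels.
def _demo_decoder_block():
    block = DecoderBlockRes1B(
        in_channels=8, out_channels=4, kernel_size=(3, 3), upsample=(2, 2),
        momentum=0.01, has_film=True,
    )
    x = torch.randn(2, 8, 8, 8)       # downsampled features
    skip = torch.randn(2, 4, 16, 16)  # encoder features at the target resolution
    film_dict = {
        'beta1': torch.zeros(2, 8, 1, 1),
        'conv_block2': {
            'beta1': torch.zeros(2, 8, 1, 1),  # out_channels * 2 = 8
            'beta2': torch.zeros(2, 4, 1, 1),
        },
    }
    y = block(x, skip, film_dict)
    assert y.shape == (2, 4, 16, 16)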


class ResUNet30_Base(nn.Module, Base):
    def __init__(self, input_channels, output_channels):
        super(ResUNet30_Base, self).__init__()

        window_size = 2048
        hop_size = 320
        center = True
        pad_mode = "reflect"
        window = "hann"
        momentum = 0.01

        self.output_channels = output_channels
        self.target_sources_num = 1
        self.K = 3

        # This number equals 2 ** {#encoder_blocks that downsample along time}.
        self.time_downsample_ratio = 2 ** 5
        self.stft = STFT(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        self.istft = ISTFT(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)

        self.pre_conv = nn.Conv2d(
            in_channels=input_channels,
            out_channels=32,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
            bias=True,
        )

        self.encoder_block1 = EncoderBlockRes1B(
            in_channels=32,
            out_channels=32,
            kernel_size=(3, 3),
            downsample=(2, 2),
            momentum=momentum,
            has_film=True,
        )
        self.encoder_block2 = EncoderBlockRes1B(
            in_channels=32,
            out_channels=64,
            kernel_size=(3, 3),
            downsample=(2, 2),
            momentum=momentum,
            has_film=True,
        )
        self.encoder_block3 = EncoderBlockRes1B(
            in_channels=64,
            out_channels=128,
            kernel_size=(3, 3),
            downsample=(2, 2),
            momentum=momentum,
            has_film=True,
        )
        self.encoder_block4 = EncoderBlockRes1B(
            in_channels=128,
            out_channels=256,
            kernel_size=(3, 3),
            downsample=(2, 2),
            momentum=momentum,
            has_film=True,
        )
        self.encoder_block5 = EncoderBlockRes1B(
            in_channels=256,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(2, 2),
            momentum=momentum,
            has_film=True,
        )
        self.encoder_block6 = EncoderBlockRes1B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(1, 2),
            momentum=momentum,
            has_film=True,
        )
        self.conv_block7a = EncoderBlockRes1B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(1, 1),
            momentum=momentum,
            has_film=True,
        )
        self.decoder_block1 = DecoderBlockRes1B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            upsample=(1, 2),
            momentum=momentum,
            has_film=True,
        )
        self.decoder_block2 = DecoderBlockRes1B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            upsample=(2, 2),
            momentum=momentum,
            has_film=True,
        )
        self.decoder_block3 = DecoderBlockRes1B(
            in_channels=384,
            out_channels=256,
            kernel_size=(3, 3),
            upsample=(2, 2),
            momentum=momentum,
            has_film=True,
        )
        self.decoder_block4 = DecoderBlockRes1B(
            in_channels=256,
            out_channels=128,
            kernel_size=(3, 3),
            upsample=(2, 2),
            momentum=momentum,
            has_film=True,
        )
        self.decoder_block5 = DecoderBlockRes1B(
            in_channels=128,
            out_channels=64,
            kernel_size=(3, 3),
            upsample=(2, 2),
            momentum=momentum,
            has_film=True,
        )
        self.decoder_block6 = DecoderBlockRes1B(
            in_channels=64,
            out_channels=32,
            kernel_size=(3, 3),
            upsample=(2, 2),
            momentum=momentum,
            has_film=True,
        )

        self.after_conv = nn.Conv2d(
            in_channels=32,
            out_channels=output_channels * self.K,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
            bias=True,
        )

        self.init_weights()

    def init_weights(self):
        init_bn(self.bn0)
        init_layer(self.pre_conv)
        init_layer(self.after_conv)
    def feature_maps_to_wav(
        self,
        input_tensor: torch.Tensor,
        sp: torch.Tensor,
        sin_in: torch.Tensor,
        cos_in: torch.Tensor,
        audio_length: int,
    ) -> torch.Tensor:
        r"""Convert feature maps to waveform.

        Args:
            input_tensor: (batch_size, target_sources_num * output_channels * self.K, time_steps, freq_bins)
            sp: (batch_size, input_channels, time_steps, freq_bins)
            sin_in: (batch_size, input_channels, time_steps, freq_bins)
            cos_in: (batch_size, input_channels, time_steps, freq_bins)

            (input_channels == output_channels for the source separation task.)

        Outputs:
            waveform: (batch_size, target_sources_num * output_channels, segment_samples)
        """
        batch_size, _, time_steps, freq_bins = input_tensor.shape

        x = input_tensor.reshape(
            batch_size,
            self.target_sources_num,
            self.output_channels,
            self.K,
            time_steps,
            freq_bins,
        )
        # x: (batch_size, target_sources_num, output_channels, self.K, time_steps, freq_bins)

        mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
        _mask_real = torch.tanh(x[:, :, :, 1, :, :])
        _mask_imag = torch.tanh(x[:, :, :, 2, :, :])
        # linear_mag = torch.tanh(x[:, :, :, 3, :, :])
        _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
        # mask_cos, mask_sin: (batch_size, target_sources_num, output_channels, time_steps, freq_bins)

        # Y = |Y|cos∠Y + j|Y|sin∠Y
        #   = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
        #   = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
        out_cos = (
            cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
        )
        out_sin = (
            sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
        )
        # out_cos: (batch_size, target_sources_num, output_channels, time_steps, freq_bins)
        # out_sin: (batch_size, target_sources_num, output_channels, time_steps, freq_bins)

        # Calculate |Y|.
        out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag)
        # out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag + linear_mag)
        # out_mag: (batch_size, target_sources_num, output_channels, time_steps, freq_bins)

        # Calculate Y_{real} and Y_{imag} for ISTFT.
        out_real = out_mag * out_cos
        out_imag = out_mag * out_sin
        # out_real, out_imag: (batch_size, target_sources_num, output_channels, time_steps, freq_bins)

        # Reformat shape to (N, 1, time_steps, freq_bins) for ISTFT, where
        # N = batch_size * target_sources_num * output_channels.
        shape = (
            batch_size * self.target_sources_num * self.output_channels,
            1,
            time_steps,
            freq_bins,
        )
        out_real = out_real.reshape(shape)
        out_imag = out_imag.reshape(shape)

        # ISTFT.
        x = self.istft(out_real, out_imag, audio_length)
        # (batch_size * target_sources_num * output_channels, segment_samples)

        # Reshape.
        waveform = x.reshape(
            batch_size, self.target_sources_num * self.output_channels, audio_length
        )
        # (batch_size, target_sources_num * output_channels, segment_samples)

        return waveform
    def forward(self, mixtures, film_dict):
        """
        Args:
            mixtures: (batch_size, channels_num, segment_samples)
            film_dict: nested Dict of FiLM modulation tensors

        Outputs:
            output_dict: {
                'waveform': (batch_size, target_sources_num * output_channels, segment_samples)}
        """
        mag, cos_in, sin_in = self.wav_to_spectrogram_phase(mixtures)

        x = mag

        # Batch normalization over the frequency axis.
        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)
        # (batch_size, channels, time_steps, freq_bins)

        # Pad the spectrogram so that time_steps is evenly divided by the
        # downsample ratio.
        origin_len = x.shape[2]
        pad_len = (
            int(np.ceil(x.shape[2] / self.time_downsample_ratio)) * self.time_downsample_ratio
            - origin_len
        )
        x = F.pad(x, pad=(0, 0, 0, pad_len))
        # (batch_size, channels, padded_time_steps, freq_bins)

        # Drop the highest frequency bin so that freq_bins is evenly divided
        # by 2, e.g., 1025 -> 1024.
        x = x[..., 0 : x.shape[-1] - 1]  # (bs, channels, T, F)

        # UNet
        x = self.pre_conv(x)
        x1_pool, x1 = self.encoder_block1(x, film_dict['encoder_block1'])  # x1_pool: (bs, 32, T / 2, F / 2)
        x2_pool, x2 = self.encoder_block2(x1_pool, film_dict['encoder_block2'])  # x2_pool: (bs, 64, T / 4, F / 4)
        x3_pool, x3 = self.encoder_block3(x2_pool, film_dict['encoder_block3'])  # x3_pool: (bs, 128, T / 8, F / 8)
        x4_pool, x4 = self.encoder_block4(x3_pool, film_dict['encoder_block4'])  # x4_pool: (bs, 256, T / 16, F / 16)
        x5_pool, x5 = self.encoder_block5(x4_pool, film_dict['encoder_block5'])  # x5_pool: (bs, 384, T / 32, F / 32)
        x6_pool, x6 = self.encoder_block6(x5_pool, film_dict['encoder_block6'])  # x6_pool: (bs, 384, T / 32, F / 64)
        x_center, _ = self.conv_block7a(x6_pool, film_dict['conv_block7a'])  # (bs, 384, T / 32, F / 64)
        x7 = self.decoder_block1(x_center, x6, film_dict['decoder_block1'])  # (bs, 384, T / 32, F / 32)
        x8 = self.decoder_block2(x7, x5, film_dict['decoder_block2'])  # (bs, 384, T / 16, F / 16)
        x9 = self.decoder_block3(x8, x4, film_dict['decoder_block3'])  # (bs, 256, T / 8, F / 8)
        x10 = self.decoder_block4(x9, x3, film_dict['decoder_block4'])  # (bs, 128, T / 4, F / 4)
        x11 = self.decoder_block5(x10, x2, film_dict['decoder_block5'])  # (bs, 64, T / 2, F / 2)
        x12 = self.decoder_block6(x11, x1, film_dict['decoder_block6'])  # (bs, 32, T, F)

        x = self.after_conv(x12)

        # Recover the original shape: re-append the dropped frequency bin and
        # crop the padded time steps.
        x = F.pad(x, pad=(0, 1))
        x = x[:, :, 0:origin_len, :]

        audio_length = mixtures.shape[2]

        # Apply the estimated complex masks to the mixture spectrogram and
        # synthesize waveforms via ISTFT.
        separated_audio = self.feature_maps_to_wav(
            input_tensor=x,
            # input_tensor: (batch_size, target_sources_num * output_channels * self.K, T, F')
            sp=mag,
            # sp: (batch_size, input_channels, T, F')
            sin_in=sin_in,
            # sin_in: (batch_size, input_channels, T, F')
            cos_in=cos_in,
            # cos_in: (batch_size, input_channels, T, F')
            audio_length=audio_length,
        )
        # (batch_size, target_sources_num * output_channels, segment_samples)

        output_dict = {'waveform': separated_audio}

        return output_dict


def get_film_meta(module):

    film_meta = {}

    if hasattr(module, 'has_film'):
        if module.has_film:
            film_meta['beta1'] = module.bn1.num_features
            film_meta['beta2'] = module.bn2.num_features
        else:
            film_meta['beta1'] = 0
            film_meta['beta2'] = 0

    for child_name, child_module in module.named_children():

        child_meta = get_film_meta(child_module)

        if len(child_meta) > 0:
            film_meta[child_name] = child_meta

    return film_meta
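
# A minimal sketch of what get_film_meta() produces for the base model: a
# nested dict mirroring the module hierarchy, where each has_film module
# contributes the feature counts its FiLM betas must match.
def _demo_film_meta():
    base = ResUNet30_Base(input_channels=1, output_channels=1)
    meta = get_film_meta(module=base)
    # encoder_block1 wraps a ConvBlockRes(32, 32), hence two 32-dim betas.
    assert meta['encoder_block1']['conv_block1'] == {'beta1': 32, 'beta2': 32}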


class ResUNet30(nn.Module):
    def __init__(self, input_channels, output_channels, condition_size):
        super(ResUNet30, self).__init__()

        self.base = ResUNet30_Base(
            input_channels=input_channels,
            output_channels=output_channels,
        )

        self.film_meta = get_film_meta(
            module=self.base,
        )

        self.film = FiLM(
            film_meta=self.film_meta,
            condition_size=condition_size,
        )

    def forward(self, input_dict):
        mixtures = input_dict['mixture']
        conditions = input_dict['condition']

        film_dict = self.film(
            conditions=conditions,
        )

        output_dict = self.base(
            mixtures=mixtures,
            film_dict=film_dict,
        )

        return output_dict
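
# An end-to-end usage sketch; the batch size, segment length, and the
# condition size of 512 are assumptions for illustration (the condition would
# typically come from a pretrained text/audio embedding model).
if __name__ == '__main__':
    model = ResUNet30(input_channels=1, output_channels=1, condition_size=512)
    model.eval()

    mixture = torch.randn(2, 1, 32000)  # (batch_size, channels_num, segment_samples)
    condition = torch.randn(2, 512)     # (batch_size, condition_size)

    with torch.no_grad():
        output_dict = model({'mixture': mixture, 'condition': condition})

    print(output_dict['waveform'].shape)  # torch.Size([2, 1, 32000])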