import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchlibrosa.stft import magphase


def init_layer(layer):
    """Initialize a Linear or Convolutional layer.

    The weight is filled with Xavier-uniform values and the bias, when the
    layer has one, is set to zero.

    Args:
        layer: module exposing a `weight` tensor and optionally a `bias`.
    """
    nn.init.xavier_uniform_(layer.weight)

    if getattr(layer, "bias", None) is not None:
        layer.bias.data.fill_(0.0)


def init_bn(bn):
    """Initialize a Batchnorm layer.

    Sets the bias to zero and the weight (scale) to one.

    Args:
        bn: module exposing `weight` and `bias` tensors.
    """
    bn.bias.data.fill_(0.0)
    bn.weight.data.fill_(1.0)


def init_embedding(layer):
    """Initialize an embedding layer.

    The weight is filled with values drawn uniformly between -1 and 1 and the
    bias, when the layer has one, is set to zero.

    Args:
        layer: module exposing a `weight` tensor and optionally a `bias`.
    """
    nn.init.uniform_(layer.weight, -1.0, 1.0)

    if getattr(layer, "bias", None) is not None:
        layer.bias.data.fill_(0.0)


def init_gru(rnn):
    """Initialize a GRU layer.

    For every layer index `i` in `range(rnn.num_layers)`:

    * `weight_ih_l{i}` is split into three equal row chunks, each filled with
      uniform values in +/- sqrt(3 / fan_in).
    * `weight_hh_l{i}` is split the same way; the first two chunks get the
      same uniform init and the last chunk is made orthogonal.
    * `bias_ih_l{i}` and `bias_hh_l{i}` are zeroed when they exist.

    NOTE: only the parameter names listed above are touched. Any other
    parameters of `rnn` (e.g. the `*_reverse` ones of a bidirectional GRU)
    keep whatever values they already had.

    Args:
        rnn: recurrent module exposing `num_layers` and the parameters above.
    """

    def _concat_init(tensor, init_funcs):
        # Apply one initializer per equally sized chunk of rows. The slices
        # are views, so the in-place initializers write into `tensor` itself.
        chunk_rows = tensor.shape[0] // len(init_funcs)
        for i, init_func in enumerate(init_funcs):
            init_func(tensor[i * chunk_rows : (i + 1) * chunk_rows, :])

    def _inner_uniform(tensor):
        fan_in = nn.init._calculate_correct_fan(tensor, "fan_in")
        bound = math.sqrt(3 / fan_in)
        nn.init.uniform_(tensor, -bound, bound)

    def _zero_bias(name):
        # A GRU built with bias=False has no bias parameters at all; skip
        # them instead of failing with AttributeError.
        bias = getattr(rnn, name, None)
        if bias is not None:
            torch.nn.init.constant_(bias, 0)

    for i in range(rnn.num_layers):
        _concat_init(
            getattr(rnn, "weight_ih_l{}".format(i)),
            [_inner_uniform, _inner_uniform, _inner_uniform],
        )
        _zero_bias("bias_ih_l{}".format(i))

        _concat_init(
            getattr(rnn, "weight_hh_l{}".format(i)),
            [_inner_uniform, _inner_uniform, nn.init.orthogonal_],
        )
        _zero_bias("bias_hh_l{}".format(i))


def act(x, activation):
    """Apply the activation selected by name to `x`.

    Args:
        x: input tensor. "relu" and "leaky_relu" use the in-place functional
            variants, so `x` itself is modified; "swish" returns a new tensor.
        activation: one of "relu", "leaky_relu" or "swish".

    Returns:
        The activated tensor.

    Raises:
        ValueError: if `activation` is not one of the supported names.
    """
    if activation == "relu":
        return F.relu_(x)

    if activation == "leaky_relu":
        return F.leaky_relu_(x, negative_slope=0.01)

    if activation == "swish":
        return x * torch.sigmoid(x)

    # ValueError is a subclass of Exception, so existing handlers still work.
    raise ValueError("Incorrect activation!")


class Base:
    """Spectrogram helpers shared by models.

    The methods call `self.stft` and `self.istft`, neither of which is
    defined in this class; whoever uses these helpers must provide both.
    `self.stft(x)` is expected to return a `(real, imag)` pair and
    `self.istft(real, imag, length)` to return a waveform.
    """

    def __init__(self):
        pass

    def spectrogram(self, input, eps=0.):
        """Return the magnitude of the STFT of `input`.

        Args:
            input: waveform passed unchanged to `self.stft`.
            eps: lower bound applied to the power (real^2 + imag^2) before
                the square root is taken.

        Returns:
            Magnitude tensor with the same shape as the STFT outputs.
        """
        real, imag = self.stft(input)
        power = real ** 2 + imag ** 2
        return torch.clamp(power, eps, np.inf) ** 0.5

    def spectrogram_phase(self, input, eps=0.):
        """Return the magnitude and the phase (as cos / sin) of the STFT.

        Args:
            input: waveform passed unchanged to `self.stft`.
            eps: lower bound applied to the power (real^2 + imag^2) before
                the square root is taken.

        Returns:
            Tuple `(mag, cos, sin)` where `cos = real / mag` and
            `sin = imag / mag`. Where `mag` is exactly zero, `cos` and `sin`
            are computed with a divisor of one instead.
        """
        real, imag = self.stft(input)
        mag = torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5

        # With eps == 0 (the default) silent bins have mag == 0 and the plain
        # division would yield 0 / 0 = NaN. Dividing by one there keeps the
        # result finite and leaves every non-zero bin unchanged.
        safe_mag = torch.where(mag > 0, mag, torch.ones_like(mag))
        cos = real / safe_mag
        sin = imag / safe_mag
        return mag, cos, sin

    def wav_to_spectrogram_phase(self, input, eps=1e-10):
        """Waveform to spectrogram and phase, computed channel by channel.

        Args:
            input: (batch_size, channels_num, segment_samples); dimension 1
                is iterated as the channel axis.
            eps: forwarded to `spectrogram_phase`.

        Outputs:
            sps, coss, sins: per-channel magnitudes, cosines and sines
                concatenated along dimension 1. Presumably
                (batch_size, channels_num, time_steps, freq_bins), assuming
                `self.stft` returns a singleton dimension 1 -- TODO confirm.
        """
        channels_num = input.shape[1]
        per_channel = [
            self.spectrogram_phase(input[:, channel, :], eps=eps)
            for channel in range(channels_num)
        ]

        sps = torch.cat([mag for (mag, _, _) in per_channel], dim=1)
        coss = torch.cat([cos for (_, cos, _) in per_channel], dim=1)
        sins = torch.cat([sin for (_, _, sin) in per_channel], dim=1)
        return sps, coss, sins

    def wav_to_spectrogram(self, input, eps=0.):
        """Waveform to spectrogram, computed channel by channel.

        Args:
            input: (batch_size, channels_num, segment_samples); dimension 1
                is iterated as the channel axis.
            eps: forwarded to `spectrogram`.

        Outputs:
            output: per-channel magnitudes concatenated along dimension 1.
                Presumably (batch_size, channels_num, time_steps, freq_bins),
                assuming `self.stft` returns a singleton dimension 1 -- TODO
                confirm.
        """
        channels_num = input.shape[1]
        sp_list = [
            self.spectrogram(input[:, channel, :], eps=eps)
            for channel in range(channels_num)
        ]
        return torch.cat(sp_list, dim=1)

    def spectrogram_to_wav(self, input, spectrogram, length=None):
        """Spectrogram to waveform, reusing the phase of `input`.

        For each channel the phase (cos, sin) is taken from the STFT of
        `input` and combined with the matching channel of `spectrogram`
        before the inverse STFT is applied.

        Args:
            input: (batch_size, channels_num, segment_samples); dimension 1
                is iterated as the channel axis.
            spectrogram: (batch_size, channels_num, time_steps, freq_bins)
            length: forwarded as the third argument of `self.istft`.

        Outputs:
            output: per-channel waveforms stacked along dimension 1, i.e.
                (batch_size, channels_num, ...) followed by whatever shape
                `self.istft` returns per batch item.
        """
        channels_num = input.shape[1]
        wav_list = []

        for channel in range(channels_num):
            real, imag = self.stft(input[:, channel, :])
            _, cos, sin = magphase(real, imag)

            # Keep the channel axis (size 1) so shapes line up with cos / sin.
            channel_sp = spectrogram[:, channel : channel + 1, :, :]
            wav_list.append(
                self.istft(channel_sp * cos, channel_sp * sin, length)
            )

        return torch.stack(wav_list, dim=1)