AudioSep/models/base.py
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchlibrosa.stft import magphase

def init_layer(layer):
"""Initialize a Linear or Convolutional layer. """
nn.init.xavier_uniform_(layer.weight)
if hasattr(layer, "bias"):
if layer.bias is not None:
layer.bias.data.fill_(0.0)

def init_bn(bn):
"""Initialize a Batchnorm layer. """
bn.bias.data.fill_(0.0)
bn.weight.data.fill_(1.0)

def init_embedding(layer):
    """Initialize an embedding layer. """
nn.init.uniform_(layer.weight, -1., 1.)
if hasattr(layer, 'bias'):
if layer.bias is not None:
layer.bias.data.fill_(0.)

def init_gru(rnn):
"""Initialize a GRU layer. """
def _concat_init(tensor, init_funcs):
(length, fan_out) = tensor.shape
fan_in = length // len(init_funcs)
for (i, init_func) in enumerate(init_funcs):
init_func(tensor[i * fan_in : (i + 1) * fan_in, :])
def _inner_uniform(tensor):
fan_in = nn.init._calculate_correct_fan(tensor, "fan_in")
nn.init.uniform_(tensor, -math.sqrt(3 / fan_in), math.sqrt(3 / fan_in))
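
    # PyTorch stacks the three GRU gates (reset, update, new) along dim 0, so
    # each weight_ih_l{i} / weight_hh_l{i} tensor has shape (3 * hidden_size, -1).
    # _concat_init splits the tensor into three equal blocks and applies one init
    # function per gate; the hidden-to-hidden "new" gate block gets an orthogonal
    # init, the remaining blocks a uniform init scaled by their fan-in.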
for i in range(rnn.num_layers):
_concat_init(
getattr(rnn, "weight_ih_l{}".format(i)),
[_inner_uniform, _inner_uniform, _inner_uniform],
)
torch.nn.init.constant_(getattr(rnn, "bias_ih_l{}".format(i)), 0)
_concat_init(
getattr(rnn, "weight_hh_l{}".format(i)),
[_inner_uniform, _inner_uniform, nn.init.orthogonal_],
)
torch.nn.init.constant_(getattr(rnn, "bias_hh_l{}".format(i)), 0)

def act(x, activation):
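    # Apply the requested non-linearity; relu_ / leaky_relu_ operate in place,
    # "swish" is x * sigmoid(x).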
if activation == "relu":
return F.relu_(x)
elif activation == "leaky_relu":
return F.leaky_relu_(x, negative_slope=0.01)
elif activation == "swish":
return x * torch.sigmoid(x)
else:
raise Exception("Incorrect activation!")

class Base:
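    """Mixin providing waveform <-> spectrogram helpers. Subclasses are expected
    to define `self.stft` and `self.istft`, modules that return and consume the
    real/imaginary STFT parts."""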
def __init__(self):
pass

    def spectrogram(self, input, eps=0.):
(real, imag) = self.stft(input)
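        # Magnitude spectrogram: sqrt(real^2 + imag^2), clamped at `eps`
        # before the square root for numerical stability.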
return torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5

    def spectrogram_phase(self, input, eps=0.):
(real, imag) = self.stft(input)
mag = torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5
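        # Dividing by the magnitude yields the phase as (cos, sin) components;
        # `eps` in the clamp keeps the division well defined.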
cos = real / mag
sin = imag / mag
return mag, cos, sin

    def wav_to_spectrogram_phase(self, input, eps=1e-10):
        """Waveform to magnitude spectrogram and phase.

        Args:
          input: (batch_size, channels_num, segment_samples)

        Outputs:
          sps: (batch_size, channels_num, time_steps, freq_bins), magnitude
          coss: (batch_size, channels_num, time_steps, freq_bins), cos of the phase
          sins: (batch_size, channels_num, time_steps, freq_bins), sin of the phase
        """
sp_list = []
cos_list = []
sin_list = []
channels_num = input.shape[1]
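        # Compute the STFT per channel and concatenate along the channel dimension.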
for channel in range(channels_num):
mag, cos, sin = self.spectrogram_phase(input[:, channel, :], eps=eps)
sp_list.append(mag)
cos_list.append(cos)
sin_list.append(sin)
sps = torch.cat(sp_list, dim=1)
coss = torch.cat(cos_list, dim=1)
sins = torch.cat(sin_list, dim=1)
return sps, coss, sins

    def wav_to_spectrogram(self, input, eps=0.):
        """Waveform to magnitude spectrogram.

        Args:
          input: (batch_size, channels_num, segment_samples)

        Outputs:
          output: (batch_size, channels_num, time_steps, freq_bins)
        """
sp_list = []
channels_num = input.shape[1]
for channel in range(channels_num):
sp_list.append(self.spectrogram(input[:, channel, :], eps=eps))
output = torch.cat(sp_list, dim=1)
return output

    def spectrogram_to_wav(self, input, spectrogram, length=None):
        """Spectrogram to waveform, reusing the phase of the input waveform.

        Args:
          input: (batch_size, channels_num, segment_samples), waveform providing the phase
          spectrogram: (batch_size, channels_num, time_steps, freq_bins), magnitude spectrogram
          length: int, number of samples of the output waveform

        Outputs:
          output: (batch_size, channels_num, segment_samples)
        """
channels_num = input.shape[1]
wav_list = []
for channel in range(channels_num):
(real, imag) = self.stft(input[:, channel, :])
(_, cos, sin) = magphase(real, imag)
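            # Combine the estimated magnitude with the (cos, sin) phase of the
            # input waveform, then invert with the ISTFT.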
wav_list.append(self.istft(spectrogram[:, channel : channel + 1, :, :] * cos,
spectrogram[:, channel : channel + 1, :, :] * sin, length))
output = torch.stack(wav_list, dim=1)
return output
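

# ---------------------------------------------------------------------------
# Usage sketch (assumption, not part of the original module): `Base` is meant
# to be mixed into an nn.Module that provides `self.stft` and `self.istft`.
# The example below wires them up with torchlibrosa's STFT/ISTFT; the class
# name and the window/hop sizes are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torchlibrosa.stft import STFT, ISTFT

    class SpectrogramBase(nn.Module, Base):
        def __init__(self, n_fft=2048, hop_length=441):
            super().__init__()
            # Real/imaginary STFT and its inverse with shared hyperparameters.
            self.stft = STFT(n_fft=n_fft, hop_length=hop_length,
                             win_length=n_fft, window="hann",
                             center=True, pad_mode="reflect")
            self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length,
                               win_length=n_fft, window="hann",
                               center=True, pad_mode="reflect")

    model = SpectrogramBase()
    waveform = torch.randn(2, 1, 44100)  # (batch_size, channels_num, samples)
    mag, cos, sin = model.wav_to_spectrogram_phase(waveform)
    recon = model.spectrogram_to_wav(waveform, mag, length=waveform.shape[2])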