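# NOTE: the next line is residue from the hosting page this file was copied from
# (uploader "HoneyTian", commit message "update", revision a813eb0); it is not code.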
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from einops.layers.torch import Rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pesq import pesq
from joblib import Parallel, delayed
def phase_losses(phase_r, phase_g):
    """Compute three anti-wrapped phase discrepancy terms.

    Each term is the mean of ``anti_wrapping_function`` applied to a
    difference between the reference and generated phase:

    * ``ip_loss``  - the phases themselves;
    * ``gd_loss``  - their first-order differences along ``dim=1``;
    * ``iaf_loss`` - their first-order differences along ``dim=2``.

    Args:
        phase_r: Reference phase tensor.
        phase_g: Generated phase tensor, broadcast-compatible with
            ``phase_r``. Both must have at least three dimensions because
            they are differenced along dims 1 and 2.
            NOTE(review): presumably laid out as (batch, freq, time), which
            would make dim 1 the frequency axis and dim 2 the time axis -
            confirm against the caller.

    Returns:
        Tuple ``(ip_loss, gd_loss, iaf_loss)`` of scalar tensors.
    """

    def _mean_wrapped_gap(ref, gen):
        # Mean absolute phase gap after folding onto the principal interval.
        return torch.mean(anti_wrapping_function(ref - gen))

    ip_loss = _mean_wrapped_gap(phase_r, phase_g)
    gd_loss = _mean_wrapped_gap(
        torch.diff(phase_r, dim=1), torch.diff(phase_g, dim=1)
    )
    iaf_loss = _mean_wrapped_gap(
        torch.diff(phase_r, dim=2), torch.diff(phase_g, dim=2)
    )
    return ip_loss, gd_loss, iaf_loss
def anti_wrapping_function(x):
    """Return the absolute distance from ``x`` to its nearest multiple of 2*pi.

    Args:
        x: Tensor of angles (or angle differences) in radians.

    Returns:
        Tensor of the same shape as ``x`` holding
        ``|x - round(x / 2pi) * 2pi|``.
    """
    period = 2 * np.pi
    nearest_multiple = torch.round(x / period) * period
    return torch.abs(x - nearest_multiple)
def pesq_score(utts_r, utts_g, h, n_jobs=30):
    """Return the mean PESQ score over paired reference/generated utterances.

    Args:
        utts_r: Indexable collection of reference utterance tensors; its
            length determines how many pairs are scored.
        utts_g: Indexable collection of generated utterance tensors, indexed
            in lockstep with ``utts_r``.
        h: Configuration object exposing a ``sample_rate`` attribute.
        n_jobs: Number of joblib workers used to score utterances in
            parallel. Defaults to 30, the value that used to be hard-coded.

    Returns:
        ``np.mean`` of the per-utterance scores returned by ``eval_pesq``.
        An utterance for which ``eval_pesq`` fails contributes ``-1`` to
        the mean.
    """

    def _as_waveform(utt):
        # detach() first: Tensor.numpy() raises on tensors that still
        # require grad, e.g. raw generator outputs.
        return utt.detach().squeeze().cpu().numpy()

    # Index-based pairing is kept on purpose so that a shorter ``utts_g``
    # still raises IndexError instead of being silently truncated.
    scores = Parallel(n_jobs=n_jobs)(
        delayed(eval_pesq)(
            _as_waveform(utts_r[i]),
            _as_waveform(utts_g[i]),
            h.sample_rate,
        )
        for i in range(len(utts_r))
    )
    return np.mean(scores)
def eval_pesq(clean_utt, esti_utt, sr):
    """Score one utterance pair with ``pesq``, returning -1 on failure.

    Args:
        clean_utt: Reference waveform, passed to ``pesq`` as the second
            argument.
        esti_utt: Estimated waveform, passed to ``pesq`` as the third
            argument.
        sr: Sample rate, passed to ``pesq`` as the first argument.

    Returns:
        The value returned by ``pesq``, or ``-1`` if ``pesq`` raised.
    """
    try:
        return pesq(sr, clean_utt, esti_utt)
    except Exception:
        # Deliberate best-effort: a failed utterance is reported as -1 rather
        # than aborting the whole batch. ``Exception`` (not a bare except)
        # lets KeyboardInterrupt and SystemExit propagate.
        return -1
def mag_pha_stft(y, n_fft, hop_size, win_size, compress_factor=1.0, center=True):
    """Compute compressed magnitude, phase and a real/imag stack from an STFT.

    Accepts anything ``torch.stft`` accepts: a batched ``(batch, samples)``
    waveform or a single unbatched ``(samples,)`` waveform.

    Args:
        y: Input waveform tensor.
        n_fft: FFT size.
        hop_size: Hop length between frames.
        win_size: Length of the Hann analysis window.
        compress_factor: Exponent applied to the magnitude (1.0 leaves it
            unchanged).
        center: Forwarded to ``torch.stft``; padding mode is ``'reflect'``.

    Returns:
        Tuple ``(mag, pha, com)``:
        ``mag`` is the magnitude raised to ``compress_factor``,
        ``pha`` is the phase from ``atan2``, and
        ``com`` stacks ``mag * cos(pha)`` and ``mag * sin(pha)`` on a new
        last dimension of size 2.
    """
    hann_window = torch.hann_window(win_size).to(y.device)
    stft_spec = torch.stft(
        y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window,
        center=center, pad_mode='reflect', normalized=False, return_complex=True,
    )
    # Trailing dimension of size 2 holds (real, imag).
    stft_spec = torch.view_as_real(stft_spec)
    # Ellipsis indexing works for both batched and unbatched spectra.
    real_part = stft_spec[..., 0]
    imag_part = stft_spec[..., 1]

    # The small constant keeps the sqrt argument strictly positive.
    mag = torch.sqrt(stft_spec.pow(2).sum(-1) + 1e-9)
    # The offsets keep atan2 away from the undefined (0, 0) point.
    pha = torch.atan2(imag_part + 1e-10, real_part + 1e-5)

    # Magnitude Compression
    mag = torch.pow(mag, compress_factor)
    com = torch.stack((mag * torch.cos(pha), mag * torch.sin(pha)), dim=-1)
    return mag, pha, com
def mag_pha_istft(mag, pha, n_fft, hop_size, win_size, compress_factor=1.0, center=True):
    """Rebuild a waveform from a compressed magnitude and a phase spectrum.

    Args:
        mag: Magnitude spectrum that was raised to ``compress_factor``.
        pha: Phase spectrum in radians, same shape as ``mag``.
        n_fft: FFT size.
        hop_size: Hop length between frames.
        win_size: Length of the Hann synthesis window.
        compress_factor: Exponent that was applied to the magnitude; it is
            undone here by raising ``mag`` to ``1.0 / compress_factor``.
        center: Forwarded to ``torch.istft``.

    Returns:
        The waveform tensor produced by ``torch.istft``.
    """
    # Magnitude Decompression
    linear_mag = torch.pow(mag, (1.0 / compress_factor))

    # Polar -> rectangular, then pack into a complex spectrum.
    real_part = linear_mag * torch.cos(pha)
    imag_part = linear_mag * torch.sin(pha)
    spectrum = torch.complex(real_part, imag_part)

    window = torch.hann_window(win_size).to(spectrum.device)
    return torch.istft(
        spectrum,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=window,
        center=center,
    )
class LearnableSigmoid1d(nn.Module):
    """Sigmoid with a learnable per-feature slope and a fixed output scale.

    Computes ``beta * sigmoid(slope * x)`` where ``slope`` is a trainable
    vector of length ``in_features`` initialised to ones. It broadcasts
    against the last dimension of ``x``.

    Args:
        in_features: Size of the last dimension of the expected input.
        beta: Constant multiplier applied to the sigmoid output.
    """

    def __init__(self, in_features, beta=1):
        super().__init__()
        self.beta = beta
        # requires_grad is spelled out explicitly; the former
        # ``slope.requiresGrad = True`` only attached a stray attribute and
        # had no effect on autograd.
        self.slope = nn.Parameter(torch.ones(in_features), requires_grad=True)

    def forward(self, x):
        # x shape: [batch_size, time_steps, spec_bins]
        return self.beta * torch.sigmoid(self.slope * x)
class LearnableSigmoid2d(nn.Module):
    """Sigmoid with a learnable per-row slope and a fixed output scale.

    Computes ``beta * sigmoid(slope * x)`` where ``slope`` is a trainable
    ``(in_features, 1)`` tensor initialised to ones. Broadcasting therefore
    requires the second-to-last dimension of ``x`` to be ``in_features``
    (or 1).
    NOTE(review): presumably x is (batch, spec_bins, time_steps) - confirm
    against the caller.

    Args:
        in_features: Size of the second-to-last dimension of the input.
        beta: Constant multiplier applied to the sigmoid output.
    """

    def __init__(self, in_features, beta=1):
        super().__init__()
        self.beta = beta
        # requires_grad is spelled out explicitly; the former
        # ``slope.requiresGrad = True`` only attached a stray attribute and
        # had no effect on autograd.
        self.slope = nn.Parameter(torch.ones(in_features, 1), requires_grad=True)

    def forward(self, x):
        return self.beta * torch.sigmoid(self.slope * x)
def main():
    """Smoke-test ``LearnableSigmoid1d`` on random input and print the output shape."""
    learnable_sigmoid = LearnableSigmoid1d(201)
    a = torch.randn(4, 100, 201)
    # Call the module itself rather than ``.forward()`` so that any
    # registered hooks run as part of the call.
    result = learnable_sigmoid(a)
    print(result.shape)


if __name__ == '__main__':
    main()