# Author: HoneyTian — commit bd94e77 ("first commit")
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
# From torchaudio
def _compute_mat_trace(input: torch.Tensor, dim1: int = -2, dim2: int = -1) -> torch.Tensor:
r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions.
Args:
input (torch.Tensor): Tensor of dimension `(..., channel, channel)`
dim1 (int, optional): the first dimension of the diagonal matrix
(Default: -1)
dim2 (int, optional): the second dimension of the diagonal matrix
(Default: -2)
Returns:
Tensor: trace of the input Tensor
"""
assert input.ndim >= 2, "The dimension of the tensor must be at least 2."
assert (
input.shape[dim1] == input.shape[dim2]
), "The size of ``dim1`` and ``dim2`` must be the same."
input = torch.diagonal(input, 0, dim1=dim1, dim2=dim2)
return input.sum(dim=-1)
def _tik_reg(mat: torch.Tensor, reg: float = 1e-7, eps: float = 1e-8) -> torch.Tensor:
    """Perform Tikhonov regularization (only modifying real part).

    Computes ``mat + (trace(mat).real * reg + eps) * I``, with one loading factor
    per matrix in the batch. The loading factor is real-valued, so for a complex
    ``mat`` only the real part of the diagonal is changed.

    Args:
        mat (torch.Tensor): input matrix (..., channel, channel)
        reg (float, optional): regularization factor, scaled by the real part of the
            trace of ``mat`` (Default: ``1e-7``)
        eps (float, optional): constant added on top of the trace-scaled term, so an
            all-zero correlation matrix still receives a non-zero diagonal
            (Default: ``1e-8``)

    Returns:
        Tensor: regularized matrix (..., channel, channel)
    """
    num_channels = mat.size(-1)
    eye = torch.eye(num_channels, dtype=mat.dtype, device=mat.device)
    # Per-matrix loading factor; [..., None, None] broadcasts it over both matrix dims.
    epsilon = _compute_mat_trace(mat).real[..., None, None] * reg
    # in case that correlation_matrix is all-zero
    epsilon = epsilon + eps
    return mat + epsilon * eye
class MultiFrameModule(nn.Module):
    """
    Multi-frame speech enhancement modules.

    Signal model and notation:
        Noisy: `x = s + n`
        Enhanced: `y = f(x)`
        Objective: `min ||s - y||`

        PSD: Power spectral density, notated eg. as `Rxx` for noisy PSD.
        IFC: Inter-frame correlation vector: PSD*u, u: selection vector. Notated as `rxx`
        RTF: Relative transfer function, also called steering vector.
    """

    def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0, real: bool = False):
        """
        Multi-Frame filtering module.

        :param num_freqs: int. Number of frequency bins used for filtering.
        :param frame_size: int. Frame size in FD domain (window length N along time).
        :param lookahead: int. Lookahead, may be used to select the output time step.
            Note: This module does not add additional padding according to lookahead!
            The time axis is always padded by ``frame_size - 1`` frames in total;
            ``lookahead`` only moves that many of them from the start to the end.
        :param real: bool. If True, the padding is built for real-valued spectrograms
            with a trailing size-2 dimension, whose time axis is dim -3 (used by
            `spec_unfold_real`). Otherwise the padding is built for complex
            spectrograms, whose time axis is dim -2 (used by `spec_unfold`).
        """
        super().__init__()
        self.num_freqs = num_freqs
        self.frame_size = frame_size
        self.real = real
        if real:
            # Pads (last dim, dim -2, time dim -3) of a [..., T, F, 2] tensor.
            self.pad = nn.ConstantPad3d((0, 0, 0, 0, frame_size - 1 - lookahead, lookahead), 0.0)
        else:
            # Pads (last dim F, time dim -2) of a [..., T, F] tensor.
            self.pad = nn.ConstantPad2d((0, 0, frame_size - 1 - lookahead, lookahead), 0.0)
        self.need_unfold = frame_size > 1
        self.lookahead = lookahead

    def spec_unfold_real(self, spec: torch.Tensor):
        """Pads and unfolds a real-valued spectrogram according to frame_size.

        Args:
            spec (Tensor): Real-valued spectrogram of shape [B, C, T, F, 2].
                The time axis is dim -3.

        Returns:
            spec (Tensor): Unfolded spectrogram of shape [B, C, N, T, F, 2],
                where N: frame_size.
        """
        if self.need_unfold:
            # unfold appends the window dim last: [B, C, T, F, 2, N] -> move N to dim 2.
            spec = self.pad(spec).unfold(-3, self.frame_size, 1)
            return spec.permute(0, 1, 5, 2, 3, 4)
        # frame_size == 1: put the singleton window dim at the same position (dim 2)
        # as the unfold branch, so both branches return the same layout.
        return spec.unsqueeze(2)

    def spec_unfold(self, spec: torch.Tensor):
        """Pads and unfolds the spectrogram according to frame_size.

        Args:
            spec (complex Tensor): Spectrogram of shape [B, C, T, F]

        Returns:
            spec (Tensor): Unfolded spectrogram of shape [B, C, T, F, N], where N: frame_size.
        """
        if self.need_unfold:
            return self.pad(spec).unfold(2, self.frame_size, 1)
        return spec.unsqueeze(-1)

    @staticmethod
    def solve(Rxx, rss, diag_eps: float = 1e-8, eps: float = 1e-7) -> torch.Tensor:
        """Compute ``inverse(Rxx_reg) @ rss``, where ``Rxx_reg`` is ``Rxx`` after Tikhonov regularization.

        :param Rxx: Tensor of shape [..., N, N].
        :param rss: Tensor of shape [..., N].
        :param diag_eps: float. Forwarded to `_tik_reg` as ``reg`` (the trace-scaled loading).
        :param eps: float. Forwarded to `_tik_reg` as ``eps`` (the constant loading).
        :return: Tensor of shape [..., N].
        """
        return torch.einsum(
            "...nm,...m->...n", torch.inverse(_tik_reg(Rxx, diag_eps, eps)), rss
        )  # [T, F, N]

    @staticmethod
    def apply_coefs(spec: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor:
        """Contract the last (filter-tap) dimension: ``sum_n spec[..., n] * coefs[..., n]``.

        :param spec: Tensor of shape [B, C, T, F, N].
        :param coefs: Tensor of shape [B, C, T, F, N].
        :return: Tensor of shape [B, C, T, F].
        """
        return torch.einsum("...n,...n->...", spec, coefs)
class DF(MultiFrameModule):
    """Deep Filtering.

    Applies complex multi-frame filter coefficients to the first ``num_freqs``
    frequency bins of a spectrogram. All higher bins are left untouched.
    """

    def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0, conj: bool = False):
        """
        :param num_freqs: int. Number of (lowest) frequency bins that are filtered.
        :param frame_size: int. Number of time frames N in each filter window.
        :param lookahead: int. Number of future frames included in each window.
        :param conj: bool. If True, the complex conjugate of the coefficients is applied.
        """
        super().__init__(num_freqs, frame_size, lookahead)
        self.conj: bool = conj

    def forward(self, spec: torch.Tensor, coefs: torch.Tensor):
        """Filter the low-frequency part of ``spec`` with ``coefs``.

        :param spec: Tensor. Real view of a complex spectrogram, [B, C, T, F, 2].
        :param coefs: Tensor. Real view of complex coefficients (trailing dim of size 2).
            After conversion to complex, it is reshaped to [B, -1, N, *rest]; `df`
            expects that result as [B, C, N, T, num_freqs] — TODO confirm with caller.
        :return: Tensor [B, C, T, F, 2]. The first ``num_freqs`` bins are replaced by the
            filtered values. In eval mode this is ``spec`` itself, modified in place;
            in training mode a copy of ``spec`` is modified instead.
        """
        windows = self.spec_unfold(torch.view_as_complex(spec))  # [B, C, T, F, N]
        low_band = windows.narrow(-2, 0, self.num_freqs)

        taps = torch.view_as_complex(coefs)
        taps = taps.view(taps.shape[0], -1, self.frame_size, *taps.shape[2:])
        if self.conj:
            taps = taps.conj()

        filtered = self.df(low_band, taps)

        # Clone only while training; at inference the input buffer is reused.
        out = spec.clone() if self.training else spec
        out[..., : self.num_freqs, :] = torch.view_as_real(filtered)
        return out

    @staticmethod
    def df(spec: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor:
        """
        Deep filter implementation using `torch.einsum`. Requires unfolded spectrogram.

        :param spec: (complex Tensor). Spectrogram of shape [B, C, T, F, N].
        :param coefs: (complex Tensor). Coefficients of shape [B, C, N, T, F].
        :return: (complex Tensor). Spectrogram of shape [B, C, T, F].
        """
        # For each (t, f), sum over the N filter taps: out[t, f] = sum_n spec[t, f, n] * coefs[n, t, f].
        return torch.einsum("...tfn,...ntf->...tf", spec, coefs)
if __name__ == "__main__":
    # Intentionally empty: this file only provides library code, with no CLI entry point.
    pass