#!/usr/bin/python3
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn


# From torchaudio
def _compute_mat_trace(input: torch.Tensor, dim1: int = -2, dim2: int = -1) -> torch.Tensor:
    r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions.

    Args:
        input (torch.Tensor): Tensor of dimension `(..., channel, channel)`
        dim1 (int, optional): the first dimension of the diagonal matrix (Default: -2)
        dim2 (int, optional): the second dimension of the diagonal matrix (Default: -1)

    Returns:
        Tensor: trace of the input Tensor
    """
    assert input.ndim >= 2, "The dimension of the tensor must be at least 2."
    assert (
        input.shape[dim1] == input.shape[dim2]
    ), "The size of ``dim1`` and ``dim2`` must be the same."
    # torch.diagonal appends the diagonal as the last dimension, hence the sum over -1.
    diag = torch.diagonal(input, 0, dim1=dim1, dim2=dim2)
    return diag.sum(dim=-1)


def _tik_reg(mat: torch.Tensor, reg: float = 1e-7, eps: float = 1e-8) -> torch.Tensor:
    """Perform Tikhonov regularization (only modifying real part).

    The value ``trace(mat).real * reg + eps`` is added to the main diagonal of ``mat``.

    Args:
        mat (torch.Tensor): input matrix (..., channel, channel)
        reg (float, optional): regularization factor (Default: ``1e-7``)
        eps (float, optional): a value to avoid the correlation matrix is all-zero
            (Default: ``1e-8``)

    Returns:
        Tensor: regularized matrix (..., channel, channel)
    """
    num_channels = mat.size(-1)
    eye = torch.eye(num_channels, dtype=mat.dtype, device=mat.device)
    # Scale the diagonal loading with the (real) trace so it adapts to the matrix magnitude.
    epsilon = _compute_mat_trace(mat).real[..., None, None] * reg
    # in case that correlation_matrix is all-zero
    epsilon = epsilon + eps
    return mat + epsilon * eye


class MultiFrameModule(nn.Module):
    """
    Multi-frame speech enhancement modules.

    Signal model and notation:
        Noisy: `x = s + n`
        Enhanced: `y = f(x)`
        Objective: `min ||s - y||`

        PSD: Power spectral density, notated e.g. as `Rxx` for noisy PSD.
        IFC: Inter-frame correlation vector: PSD*u, u: selection vector. Notated as `rxx`
        RTF: Relative transfer function, also called steering vector.
    """

    def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0, real: bool = False):
        """
        Multi-Frame filtering module.

        :param num_freqs: int. Number of frequency bins used for filtering.
        :param frame_size: int. Frame size in FD domain.
        :param lookahead: int. Lookahead, may be used to select the output time step. Note: This
            module does not add additional padding according to lookahead!
        :param real: bool. If True, the padding layer is a `ConstantPad3d` that pads the
            third-last dimension (as used by `spec_unfold_real`); otherwise a `ConstantPad2d`
            that pads the second-last dimension (as used by `spec_unfold`).
        """
        super().__init__()
        self.num_freqs = num_freqs
        self.frame_size = frame_size
        self.real = real
        self.lookahead = lookahead
        # Zero-pad the time axis so that unfolding yields one window per input frame:
        # `frame_size - 1 - lookahead` frames before and `lookahead` frames after.
        pad_before = frame_size - 1 - lookahead
        if real:
            self.pad = nn.ConstantPad3d((0, 0, 0, 0, pad_before, lookahead), 0.0)
        else:
            self.pad = nn.ConstantPad2d((0, 0, pad_before, lookahead), 0.0)
        # With a single frame there is nothing to unfold; only a singleton axis is added.
        self.need_unfold = frame_size > 1

    def spec_unfold_real(self, spec: torch.Tensor):
        """Pads and unfolds a real-valued spectrogram according to frame_size.

        Args:
            spec (Tensor): 5-D tensor whose third-last dimension is unfolded
                (presumably [B, C, T, F, 2] -- TODO confirm against callers).

        Returns:
            spec (Tensor): 6-D tensor with the frame axis N inserted at dim 2,
                i.e. [B, C, N, T, F, 2] under the assumption above.
        """
        if self.need_unfold:
            # unfold appends the window axis last; move it to dim 2.
            windows = self.pad(spec).unfold(-3, self.frame_size, 1)
            return windows.permute(0, 1, 5, 2, 3, 4)
        # frame_size == 1: insert the singleton frame axis at dim 2 so the layout matches
        # what the unfold + permute path produces for N == 1.
        return spec.unsqueeze(2)

    def spec_unfold(self, spec: torch.Tensor):
        """Pads and unfolds the spectrogram according to frame_size.

        Args:
            spec (complex Tensor): Spectrogram of shape [B, C, T, F]
        Returns:
            spec (Tensor): Unfolded spectrogram of shape [B, C, T, F, N], where N: frame_size.
        """
        if not self.need_unfold:
            return spec.unsqueeze(-1)
        return self.pad(spec).unfold(2, self.frame_size, 1)

    @staticmethod
    def solve(Rxx, rss, diag_eps: float = 1e-8, eps: float = 1e-7) -> torch.Tensor:
        """Compute `inv(Rxx_reg) @ rss`, where `Rxx_reg` is the Tikhonov-regularized `Rxx`.

        :param Rxx: Tensor of shape [..., N, N].
        :param rss: Tensor of shape [..., N].
        :param diag_eps: float. Passed to `_tik_reg` as the trace-relative factor `reg`.
        :param eps: float. Passed to `_tik_reg` as the absolute diagonal offset `eps`.
        :return: Tensor of shape [..., N].
        """
        Rxx_inv = torch.inverse(_tik_reg(Rxx, diag_eps, eps))
        return torch.einsum("...nm,...m->...n", Rxx_inv, rss)  # [T, F, N]

    @staticmethod
    def apply_coefs(spec: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor:
        """Multiply `spec` and `coefs` element-wise and sum over the last dimension."""
        # spec: [B, C, T, F, N]
        # coefs: [B, C, T, F, N]
        return torch.einsum("...n,...n->...", spec, coefs)


class DF(MultiFrameModule):
    """Deep Filtering."""

    def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0, conj: bool = False):
        """
        :param num_freqs: int. Number of (lowest) frequency bins the filter is applied to.
        :param frame_size: int. Number of filter taps N.
        :param lookahead: int. Number of future frames covered by the filter.
        :param conj: bool. If True, the coefficients are complex-conjugated before filtering.
        """
        super().__init__(num_freqs, frame_size, lookahead)
        self.conj: bool = conj

    def forward(self, spec: torch.Tensor, coefs: torch.Tensor):
        """Apply the deep filter to the first `num_freqs` bins of `spec`.

        NOTE: Outside of training mode `spec` is modified in place and returned; in training
        mode it is cloned first, so the caller's tensor is left untouched.

        :param spec: (real Tensor). Spectrogram of shape [B, C, T, F, 2] (real/imag last).
        :param coefs: (real Tensor). Coefficients of shape [B, C * N, T, F', 2], where F' is
            the number of filtered bins (must match `num_freqs` for the einsum in `df`).
        :return: (real Tensor). Spectrogram of shape [B, C, T, F, 2] whose first `num_freqs`
            bins are replaced by the filtered result.
        """
        spec_u = self.spec_unfold(torch.view_as_complex(spec))  # [B, C, T, F, N]
        spec_f = spec_u.narrow(-2, 0, self.num_freqs)  # [B, C, T, F', N]
        coefs = torch.view_as_complex(coefs)
        # Split the merged channel/tap axis: [B, C * N, T, F'] -> [B, C, N, T, F']
        coefs = coefs.view(coefs.shape[0], -1, self.frame_size, *coefs.shape[2:])
        if self.conj:
            coefs = coefs.conj()
        spec_f = self.df(spec_f, coefs)
        if self.training:
            spec = spec.clone()
        spec[..., : self.num_freqs, :] = torch.view_as_real(spec_f)
        return spec

    @staticmethod
    def df(spec: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor:
        """
        Deep filter implementation using `torch.einsum`. Requires unfolded spectrogram.

        :param spec: (complex Tensor). Spectrogram of shape [B, C, T, F, N].
        :param coefs: (complex Tensor). Coefficients of shape [B, C, N, T, F].
        :return: (complex Tensor). Spectrogram of shape [B, C, T, F].
        """
        return torch.einsum("...tfn,...ntf->...tf", spec, coefs)


if __name__ == '__main__':
    pass