Spaces:
Running
Running
File size: 5,589 Bytes
bd94e77 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
# From torchaudio
def _compute_mat_trace(input: torch.Tensor, dim1: int = -2, dim2: int = -1) -> torch.Tensor:
r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions.
Args:
input (torch.Tensor): Tensor of dimension `(..., channel, channel)`
dim1 (int, optional): the first dimension of the diagonal matrix
(Default: -1)
dim2 (int, optional): the second dimension of the diagonal matrix
(Default: -2)
Returns:
Tensor: trace of the input Tensor
"""
assert input.ndim >= 2, "The dimension of the tensor must be at least 2."
assert (
input.shape[dim1] == input.shape[dim2]
), "The size of ``dim1`` and ``dim2`` must be the same."
input = torch.diagonal(input, 0, dim1=dim1, dim2=dim2)
return input.sum(dim=-1)
def _tik_reg(mat: torch.Tensor, reg: float = 1e-7, eps: float = 1e-8) -> torch.Tensor:
    """Perform Tikhonov regularization (only modifying the real part).

    Adds ``real(trace(mat)) * reg + eps`` to every diagonal element of ``mat``.
    This is a diagonal load proportional to the matrix trace, plus a small
    absolute floor. The load is real-valued, so for complex inputs only the real
    part of the diagonal changes.

    Args:
        mat (torch.Tensor): input matrix (..., channel, channel)
        reg (float, optional): regularization factor, applied relative to the real
            part of the trace of ``mat`` (Default: ``1e-7``)
        eps (float, optional): absolute value added so that the result is not singular
            even when the correlation matrix is all-zero (Default: ``1e-8``)

    Returns:
        Tensor: regularized matrix (..., channel, channel)
    """
    num_ch = mat.size(-1)
    eye = torch.eye(num_ch, dtype=mat.dtype, device=mat.device)
    # Per-matrix load, shaped (..., 1, 1) so it broadcasts against the (C, C) identity.
    epsilon = _compute_mat_trace(mat).real[..., None, None] * reg
    # In case the correlation matrix is all-zero (trace == 0).
    epsilon = epsilon + eps
    return mat + epsilon * eye
class MultiFrameModule(nn.Module):
    """
    Multi-frame speech enhancement modules.

    Signal model and notation:
        Noisy: `x = s + n`
        Enhanced: `y = f(x)`
        Objective: `min ||s - y||`

        PSD: Power spectral density, notated eg. as `Rxx` for noisy PSD.
        IFC: Inter-frame correlation vector: PSD*u, u: selection vector. Notated as `rxx`
        RTF: Relative transfer function, also called steering vector.
    """

    def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0, real: bool = False):
        """
        Multi-Frame filtering module.

        :param num_freqs: int. Number of frequency bins used for filtering.
        :param frame_size: int. Frame size in FD domain.
        :param lookahead: int. Lookahead, may be used to select the output time step.
            Note: This module does not add additional padding according to lookahead!
        :param real: bool. Selects the padding layer. If True, a 3-D pad is built for
            real-valued spectrograms with a trailing component axis ([B, C, T, F, D],
            used by `spec_unfold_real`). Otherwise a 2-D pad is built for complex
            spectrograms ([B, C, T, F], used by `spec_unfold`).
        """
        super().__init__()
        self.num_freqs = num_freqs
        self.frame_size = frame_size
        self.real = real
        # The time axis receives `frame_size - 1 - lookahead` leading and `lookahead`
        # trailing zero frames, so unfolding with step 1 yields one window per input frame.
        pad_past = frame_size - 1 - lookahead
        if real:
            # ConstantPad3d pads the last three dims; time is dim -3 of [B, C, T, F, D].
            self.pad = nn.ConstantPad3d((0, 0, 0, 0, pad_past, lookahead), 0.0)
        else:
            # ConstantPad2d pads the last two dims; time is dim -2 of [B, C, T, F].
            self.pad = nn.ConstantPad2d((0, 0, pad_past, lookahead), 0.0)
        self.need_unfold = frame_size > 1
        self.lookahead = lookahead

    def spec_unfold_real(self, spec: torch.Tensor) -> torch.Tensor:
        """Pads and unfolds a real-valued spectrogram according to frame_size.

        Args:
            spec (Tensor): Spectrogram of shape [B, C, T, F, D]
                (D presumably 2 for real/imaginary parts; TODO confirm with callers).

        Returns:
            spec (Tensor): Unfolded spectrogram of shape [B, C, N, T, F, D], where N: frame_size.
        """
        if not self.need_unfold:
            # frame_size == 1: insert the singleton N axis at the same position the
            # unfolded path uses (dim 2), so the output layout does not depend on frame_size.
            return spec.unsqueeze(2)
        windows = self.pad(spec).unfold(-3, self.frame_size, 1)  # [B, C, T, F, D, N]
        return windows.permute(0, 1, 5, 2, 3, 4)  # [B, C, N, T, F, D]

    def spec_unfold(self, spec: torch.Tensor) -> torch.Tensor:
        """Pads and unfolds the spectrogram according to frame_size.

        Args:
            spec (complex Tensor): Spectrogram of shape [B, C, T, F]

        Returns:
            spec (Tensor): Unfolded spectrogram of shape [B, C, T, F, N], where N: frame_size.
        """
        if self.need_unfold:
            return self.pad(spec).unfold(2, self.frame_size, 1)
        return spec.unsqueeze(-1)

    @staticmethod
    def solve(Rxx, rss, diag_eps: float = 1e-8, eps: float = 1e-7) -> torch.Tensor:
        """Solve ``Rxx @ w = rss`` for ``w`` after Tikhonov-regularizing ``Rxx``.

        :param Rxx: Square matrices of shape [..., N, N].
        :param rss: Vectors of shape [..., N].
        :param diag_eps: Passed to `_tik_reg` as its relative `reg` factor.
        :param eps: Passed to `_tik_reg` as its absolute `eps` floor.
        :return: Solution vectors of shape [..., N] (e.g. [T, F, N]).
        """
        rxx_inv = torch.inverse(_tik_reg(Rxx, diag_eps, eps))
        return torch.einsum("...nm,...m->...n", rxx_inv, rss)  # [T, F, N]

    @staticmethod
    def apply_coefs(spec: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor:
        """Weight and sum ``spec`` with ``coefs`` over the last (frame) axis, without conjugation.

        :param spec: Tensor of shape [B, C, T, F, N].
        :param coefs: Tensor of shape [B, C, T, F, N].
        :return: Tensor of shape [B, C, T, F].
        """
        return torch.einsum("...n,...n->...", spec, coefs)
class DF(MultiFrameModule):
    """Deep Filtering.

    Applies complex multi-frame filter coefficients to the first ``num_freqs``
    frequency bins of a spectrogram. The remaining bins are passed through unchanged.
    """

    def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0, conj: bool = False):
        super().__init__(num_freqs, frame_size, lookahead)
        # If True, coefficients are complex-conjugated before filtering.
        self.conj: bool = conj

    def forward(self, spec: torch.Tensor, coefs: torch.Tensor):
        """Filter the first ``num_freqs`` bins of ``spec`` with ``coefs``.

        :param spec: Real view of a complex spectrogram; the trailing axis has size 2
            (required by `torch.view_as_complex`).
        :param coefs: Real view of complex coefficients. After `view_as_complex`, its
            dim 1 is split into (-1, frame_size).
        :return: ``spec`` with its first ``num_freqs`` bins replaced by the filtered
            output. In training mode a clone is written and returned. Otherwise
            ``spec`` itself is modified in place and returned.
        """
        n_df = self.num_freqs
        # Complex, padded, time-unfolded spectrogram: [B, C, T, F, N]
        unfolded = self.spec_unfold(torch.view_as_complex(spec))
        weights = torch.view_as_complex(coefs)
        low_band = unfolded.narrow(-2, 0, n_df)
        # Split dim 1 into (channels, frame_size): [B, C, N, ...]
        batch = weights.shape[0]
        weights = weights.view(batch, -1, self.frame_size, *weights.shape[2:])
        weights = weights.conj() if self.conj else weights
        filtered = self.df(low_band, weights)
        # Clone only while training (keeps the autograd input intact); in eval mode
        # the caller's tensor is overwritten directly.
        target = spec.clone() if self.training else spec
        target[..., :n_df, :] = torch.view_as_real(filtered)
        return target

    @staticmethod
    def df(spec: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor:
        """
        Deep filter implementation using `torch.einsum`. Requires unfolded spectrogram.

        :param spec: (complex Tensor). Spectrogram of shape [B, C, T, F, N].
        :param coefs: (complex Tensor). Coefficients of shape [B, C, N, T, F].
        :return: (complex Tensor). Spectrogram of shape [B, C, T, F].
        """
        # Multiply per (t, f) and sum over the N frames; no implicit conjugation.
        equation = "...tfn,...ntf->...tf"
        return torch.einsum(equation, spec, coefs)
if __name__ == "__main__":
    # Library module: nothing is executed when run as a script.
    ...
|