|
from typing import Optional
from typing import Tuple

from pytorch_wpe import wpe_one_iteration
import torch
from torch_complex.tensor import ComplexTensor

from espnet.nets.pytorch_backend.frontends.mask_estimator import MaskEstimator
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
|
|
|
|
|
class DNN_WPE(torch.nn.Module):
    """Weighted Prediction Error (WPE) front-end with an optional DNN mask.

    Each iteration estimates the signal power from the current enhanced
    signal and hands it, together with the original observation, to
    ``wpe_one_iteration``.  When ``use_dnn_mask`` is enabled, the power of
    the first iteration is additionally weighted by a mask predicted by a
    ``MaskEstimator``.
    """

    def __init__(
        self,
        wtype: str = "blstmp",
        widim: int = 257,
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        dropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask: bool = True,
        iterations: int = 1,
        normalization: bool = False,
    ):
        """Build the module.

        Args:
            wtype: Forwarded as-is to ``MaskEstimator`` (1st positional arg).
            widim: Forwarded as-is to ``MaskEstimator`` (2nd positional arg).
            wlayers: Forwarded as-is to ``MaskEstimator`` (3rd positional arg).
            wunits: Forwarded as-is to ``MaskEstimator`` (4th positional arg).
            wprojs: Forwarded as-is to ``MaskEstimator`` (5th positional arg).
            dropout_rate: Forwarded as-is to ``MaskEstimator``
                (6th positional arg).
            taps: ``taps`` argument passed to ``wpe_one_iteration``.
            delay: ``delay`` argument passed to ``wpe_one_iteration``.
            use_dnn_mask: If true, create a single-mask ``MaskEstimator``
                and weight the first-iteration power with its output.
            iterations: Number of WPE iterations run by ``forward``.
            normalization: If true, divide the estimated mask by its sum
                over the last (time) axis before applying it.
        """
        super().__init__()
        self.iterations = iterations
        self.taps = taps
        self.delay = delay

        self.normalization = normalization
        self.use_dnn_mask = use_dnn_mask

        # Always handed to wpe_one_iteration as ``inverse_power=True``.
        self.inverse_power = True

        # The estimator (and its parameters) only exists when it is used.
        if self.use_dnn_mask:
            self.mask_est = MaskEstimator(
                wtype, widim, wlayers, wunits, wprojs, dropout_rate, nmask=1
            )

    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, Optional[torch.Tensor]]:
        """Run ``self.iterations`` WPE iterations on ``data``.

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq or Some dimension of the feature vector

        Args:
            data: (B, T, C, F).  Axis 1 is the one zero-padded according to
                ``ilens``; axis 2 is the one averaged out of the power
                (the channel axis, assuming ``wpe_one_iteration`` takes
                (..., C, T) input -- confirm against pytorch_wpe).
            ilens: (B,)
        Returns:
            enhanced: (B, T, C, F), same layout as ``data``.
            ilens: (B,), returned unchanged.
            mask: The estimated mask with its last and third-from-last axes
                swapped back (presumably (B, T, C, F)), or ``None`` when
                ``use_dnn_mask`` is false or ``iterations`` is 0.
        """
        # (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)
        enhanced = data
        mask = None

        # permute() only returns a view.  The observation passed to every
        # WPE iteration is always this same tensor, so make it contiguous
        # once here instead of copying it again on each iteration.
        observation = data.contiguous()

        for i in range(self.iterations):
            # Power of the current estimate, same shape as ``enhanced``.
            power = enhanced.real ** 2 + enhanced.imag ** 2
            if i == 0 and self.use_dnn_mask:
                # The estimator is built with nmask=1, so exactly one mask
                # comes back; it is applied in the first iteration only.
                (mask,), _ = self.mask_est(enhanced, ilens)
                if self.normalization:
                    # Normalise along the last (time) axis.
                    mask = mask / mask.sum(dim=-1, keepdim=True)
                power = power * mask

            # Average over the second-to-last axis: (..., C, T) -> (..., T)
            power = power.mean(dim=-2)

            enhanced = wpe_one_iteration(
                observation,
                power,
                taps=self.taps,
                delay=self.delay,
                inverse_power=self.inverse_power,
            )

            # Zero the frames beyond each sequence's length.
            enhanced.masked_fill_(make_pad_mask(ilens, enhanced.real), 0)

        # (B, F, C, T) -> (B, T, C, F)
        enhanced = enhanced.permute(0, 3, 2, 1)
        if mask is not None:
            mask = mask.transpose(-1, -3)
        return enhanced, ilens, mask
|
|