|
from collections import OrderedDict |
|
from typing import List |
|
from typing import Tuple |
|
|
|
import torch |
|
from torch_complex.tensor import ComplexTensor |
|
|
|
from espnet2.enh.layers.dnn_beamformer import DNN_Beamformer |
|
from espnet2.enh.layers.dnn_wpe import DNN_WPE |
|
from espnet2.enh.separator.abs_separator import AbsSeparator |
|
|
|
|
|
class NeuralBeamformer(AbsSeparator):
    """Separator chaining an optional ``DNN_WPE`` stage and an optional ``DNN_Beamformer``.

    Both sub-modules are built in ``__init__`` from the constructor arguments;
    a disabled stage is stored as ``None``.  ``forward`` either predicts masks
    only (training with a ``mask*`` loss) or runs the enabled stages and
    returns the enhanced spectra together with the intermediate results.
    """

    def __init__(
        self,
        input_dim: int,
        num_spk: int = 1,
        loss_type: str = "mask_mse",
        # WPE-related options
        use_wpe: bool = False,
        wnet_type: str = "blstmp",
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        wdropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask_for_wpe: bool = True,
        wnonlinear: str = "crelu",
        multi_source_wpe: bool = True,
        wnormalization: bool = False,
        # Beamformer-related options
        use_beamformer: bool = True,
        bnet_type: str = "blstmp",
        blayers: int = 3,
        bunits: int = 300,
        bprojs: int = 320,
        badim: int = 320,
        ref_channel: int = -1,
        use_noise_mask: bool = True,
        bnonlinear: str = "sigmoid",
        beamformer_type: str = "mvdr_souden",
        rtf_iterations: int = 2,
        bdropout_rate: float = 0.0,
        shared_power: bool = True,
        # Options shared by both stages (numerical stabilisation)
        diagonal_loading: bool = True,
        diag_eps_wpe: float = 1e-7,
        diag_eps_bf: float = 1e-7,
        mask_flooring: bool = False,
        flooring_thres_wpe: float = 1e-6,
        flooring_thres_bf: float = 1e-6,
        use_torch_solver: bool = True,
    ):
        """Build the WPE and beamformer sub-modules.

        Args:
            input_dim: feature dimension, forwarded as ``widim`` to
                ``DNN_WPE`` and as ``bidim`` to ``DNN_Beamformer``.
            num_spk: number of speakers; exposed through ``num_spk``.
            loss_type: one of ``"mask_mse"``, ``"spectrum"``,
                ``"spectrum_log"`` or ``"magnitude"``.
            use_wpe: build a ``DNN_WPE`` module (``self.wpe``), else ``None``.
            wnet_type, wlayers, wunits, wprojs, wdropout_rate, wnonlinear,
            wnormalization: forwarded unchanged to ``DNN_WPE``.
            taps, delay: forwarded to ``DNN_WPE`` (``taps``/``delay``) and to
                ``DNN_Beamformer`` (``btaps``/``bdelay``).
            use_dnn_mask_for_wpe: forwarded as ``use_dnn_mask``; also selects
                the number of WPE iterations (1 when True, 2 otherwise).
            multi_source_wpe: when True the WPE module gets ``nmask=1``,
                otherwise ``nmask=num_spk``.
            use_beamformer: build a ``DNN_Beamformer`` (``self.beamformer``),
                else ``None``.
            bnet_type, blayers, bunits, bprojs, badim, ref_channel,
            use_noise_mask, bnonlinear, beamformer_type, rtf_iterations,
            bdropout_rate: forwarded unchanged to ``DNN_Beamformer``.
            shared_power: pass the powers returned by the WPE module on to the
                beamformer; only effective when ``use_wpe`` is True.
            diagonal_loading, mask_flooring, use_torch_solver: forwarded to
                both sub-modules.
            diag_eps_wpe, flooring_thres_wpe: forwarded to ``DNN_WPE``.
            diag_eps_bf, flooring_thres_bf: forwarded to ``DNN_Beamformer``.

        Raises:
            ValueError: if ``loss_type`` is not one of the supported values.
        """
        super().__init__()

        if loss_type not in ("mask_mse", "spectrum", "spectrum_log", "magnitude"):
            raise ValueError("Unsupported loss type: %s" % loss_type)

        self._num_spk = num_spk
        self.loss_type = loss_type

        self.use_beamformer = use_beamformer
        self.use_wpe = use_wpe

        if self.use_wpe:
            # One WPE iteration when a DNN mask is used, two otherwise.
            iterations = 1 if use_dnn_mask_for_wpe else 2

            self.wpe = DNN_WPE(
                wtype=wnet_type,
                widim=input_dim,
                wlayers=wlayers,
                wunits=wunits,
                wprojs=wprojs,
                dropout_rate=wdropout_rate,
                taps=taps,
                delay=delay,
                use_dnn_mask=use_dnn_mask_for_wpe,
                nmask=1 if multi_source_wpe else num_spk,
                nonlinear=wnonlinear,
                iterations=iterations,
                normalization=wnormalization,
                diagonal_loading=diagonal_loading,
                diag_eps=diag_eps_wpe,
                mask_flooring=mask_flooring,
                flooring_thres=flooring_thres_wpe,
                use_torch_solver=use_torch_solver,
            )
        else:
            self.wpe = None

        self.ref_channel = ref_channel
        if self.use_beamformer:
            self.beamformer = DNN_Beamformer(
                bidim=input_dim,
                btype=bnet_type,
                blayers=blayers,
                bunits=bunits,
                bprojs=bprojs,
                num_spk=num_spk,
                use_noise_mask=use_noise_mask,
                nonlinear=bnonlinear,
                dropout_rate=bdropout_rate,
                badim=badim,
                ref_channel=ref_channel,
                beamformer_type=beamformer_type,
                rtf_iterations=rtf_iterations,
                btaps=taps,
                bdelay=delay,
                diagonal_loading=diagonal_loading,
                diag_eps=diag_eps_bf,
                mask_flooring=mask_flooring,
                flooring_thres=flooring_thres_bf,
                use_torch_solver=use_torch_solver,
            )
        else:
            self.beamformer = None

        # Powers can only be shared when there is a WPE stage to produce them.
        self.shared_power = shared_power and use_wpe

    def _store_beamformer_masks(self, others, others_b):
        """Copy the beamformer masks into ``others`` under their output keys.

        The first ``num_spk`` entries of ``others_b`` become ``mask_spk{n}``;
        an additional entry, when present, is stored as ``mask_noise1``.
        """
        for spk in range(self.num_spk):
            others["mask_spk{}".format(spk + 1)] = others_b[spk]
        if len(others_b) > self.num_spk:
            others["mask_noise1"] = others_b[self.num_spk]

    def forward(
        self, input: ComplexTensor, ilens: torch.Tensor
    ) -> Tuple[List[ComplexTensor], torch.Tensor, OrderedDict]:
        """Forward.

        Args:
            input (ComplexTensor): mixed speech [Batch, Frames, Channel, Freq]
                (a 3-dim input without the channel axis is also accepted)
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            enhanced speech (single-channel): List[ComplexTensor]
                (``None`` when the module is in training mode and
                ``loss_type`` starts with "mask": only masks are predicted)
            output lengths
            other predicted data: OrderedDict[
                'dereverb1': ComplexTensor(Batch, Frames, Channel, Freq),
                'mask_dereverb1': torch.Tensor(Batch, Frames, Channel, Freq),
                'mask_noise1': torch.Tensor(Batch, Frames, Channel, Freq),
                'mask_spk1': torch.Tensor(Batch, Frames, Channel, Freq),
                'mask_spk2': torch.Tensor(Batch, Frames, Channel, Freq),
                ...
                'mask_spkn': torch.Tensor(Batch, Frames, Channel, Freq),
            ]

        Raises:
            NotImplementedError: if the WPE stage returns a list of signals
                and the beamformer is enabled.
        """
        # Only 3-dim and 4-dim inputs are handled below.
        assert input.dim() in (3, 4), input.dim()
        single_channel = input.dim() == 3
        enhanced = input
        others = OrderedDict()

        if (
            self.training
            and self.loss_type is not None
            and self.loss_type.startswith("mask")
        ):
            # Mask-only mode: predict masks without running the enhancement.
            if self.use_wpe:
                # The WPE module is always fed a 4-dim tensor; a channel axis
                # is inserted for 3-dim input and removed from the masks again.
                wpe_input = input.unsqueeze(-2) if single_channel else input
                mask_w, ilens = self.wpe.predict_mask(wpe_input, ilens)

                if mask_w is not None:
                    if isinstance(mask_w, list):
                        # One mask per source.
                        # NOTE(review): assumes predict_mask returns a list
                        # when the WPE module has several masks -- confirm.
                        for idx, mask in enumerate(mask_w, start=1):
                            if single_channel:
                                mask = mask.squeeze(-2)
                            others["mask_dereverb{}".format(idx)] = mask
                    else:
                        if single_channel:
                            mask_w = mask_w.squeeze(-2)
                        others["mask_dereverb1"] = mask_w

            # The beamformer is only applied to 4-dim input.
            if self.use_beamformer and not single_channel:
                others_b, ilens = self.beamformer.predict_mask(input, ilens)
                self._store_beamformer_masks(others, others_b)

            return None, ilens, others

        # Full mode: mask estimation and enhancement.
        powers = None

        if single_channel:
            # 3-dim input: only the WPE stage is applied.
            if self.use_wpe:
                enhanced, ilens, mask_w, powers = self.wpe(
                    input.unsqueeze(-2), ilens
                )
                if isinstance(enhanced, list):
                    # One output per source.
                    enhanced = [enh.squeeze(-2) for enh in enhanced]
                    if mask_w is not None:
                        for spk in range(self.num_spk):
                            key = "dereverb{}".format(spk + 1)
                            others[key] = enhanced[spk]
                            others["mask_" + key] = mask_w[spk].squeeze(-2)
                else:
                    # Single output.
                    enhanced = enhanced.squeeze(-2)
                    if mask_w is not None:
                        others["dereverb1"] = enhanced
                        others["mask_dereverb1"] = mask_w.squeeze(-2)
        else:
            # 4-dim input.
            # 1. WPE
            if self.use_wpe:
                enhanced, ilens, mask_w, powers = self.wpe(input, ilens)
                if mask_w is not None:
                    if isinstance(enhanced, list):
                        # One output per source.
                        for spk in range(self.num_spk):
                            key = "dereverb{}".format(spk + 1)
                            others[key] = enhanced[spk]
                            others["mask_" + key] = mask_w[spk]
                    else:
                        # Single output.
                        others["dereverb1"] = enhanced
                        # squeeze(-2) only has an effect when that axis has
                        # size 1; kept as in the original implementation.
                        others["mask_dereverb1"] = mask_w.squeeze(-2)

            # 2. Beamformer
            if self.use_beamformer:
                # Powers from the WPE stage are passed on only for the
                # "wmpdr*" / "wpd*" beamformer types, when sharing is enabled,
                # and when the WPE module does not use a single mask for
                # several speakers.  The former test
                # ``not startswith("wmpdr") or not startswith("wpd")`` was
                # always true, so the powers were never shared.
                power_based = self.beamformer.beamformer_type.startswith(
                    ("wmpdr", "wpd")
                )
                if (
                    not power_based
                    or not self.shared_power
                    or (self.wpe.nmask == 1 and self.num_spk > 1)
                ):
                    powers = None

                if isinstance(enhanced, list):
                    # The WPE stage produced one signal per source.
                    raise NotImplementedError(
                        "Single-source WPE is not supported with beamformer "
                        "in multi-speaker cases."
                    )

                enhanced, ilens, others_b = self.beamformer(
                    enhanced, ilens, powers=powers
                )
                self._store_beamformer_masks(others, others_b)

        # Always return the enhanced signals as a list.
        if not isinstance(enhanced, list):
            enhanced = [enhanced]

        return enhanced, ilens, others

    @property
    def num_spk(self):
        """Number of speakers given at construction time."""
        return self._num_spk
|
|