from collections import OrderedDict
from typing import List
from typing import Tuple
from typing import Union

import torch
from torch_complex.tensor import ComplexTensor

from espnet.nets.pytorch_backend.rnn.encoders import RNN
from espnet2.enh.separator.abs_separator import AbsSeparator


class RNNSeparator(AbsSeparator):
    def __init__(
        self,
        input_dim: int,
        rnn_type: str = "blstm",
        num_spk: int = 2,
        nonlinear: str = "sigmoid",
        layer: int = 3,
        unit: int = 512,
        dropout: float = 0.0,
    ):
"""RNN Separator |
|
|
|
Args: |
|
input_dim: input feature dimension |
|
rnn_type: string, select from 'blstm', 'lstm' etc. |
|
bidirectional: bool, whether the inter-chunk RNN layers are bidirectional. |
|
num_spk: number of speakers |
|
nonlinear: the nonlinear function for mask estimation, |
|
select from 'relu', 'tanh', 'sigmoid' |
|
layer: int, number of stacked RNN layers. Default is 3. |
|
unit: int, dimension of the hidden state. |
|
dropout: float, dropout ratio. Default is 0. |
|
""" |
|
        super().__init__()

        self._num_spk = num_spk

        self.rnn = RNN(
            idim=input_dim,
            elayers=layer,
            cdim=unit,
            hdim=unit,
            dropout=dropout,
            typ=rnn_type,
        )

        # one mask-estimation head (linear layer) per speaker
        self.linear = torch.nn.ModuleList(
            [torch.nn.Linear(unit, input_dim) for _ in range(self.num_spk)]
        )

        if nonlinear not in ("sigmoid", "relu", "tanh"):
            raise ValueError("Not supporting nonlinear={}".format(nonlinear))

        self.nonlinear = {
            "sigmoid": torch.nn.Sigmoid(),
            "relu": torch.nn.ReLU(),
            "tanh": torch.nn.Tanh(),
        }[nonlinear]

    def forward(
        self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor
    ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
        """Forward.

        Args:
            input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            masked (List[Union[torch.Tensor, ComplexTensor]]): [(B, T, N), ...]
            ilens (torch.Tensor): (B,)
            others (OrderedDict): predicted data, e.g. masks:
                OrderedDict[
                    'mask_spk1': torch.Tensor(Batch, Frames, Freq),
                    'mask_spk2': torch.Tensor(Batch, Frames, Freq),
                    ...
                    'mask_spkn': torch.Tensor(Batch, Frames, Freq),
                ]
        """
        # if the input is complex, estimate masks from its magnitude
        if isinstance(input, ComplexTensor):
            feature = abs(input)
        else:
            feature = input

        x, ilens, _ = self.rnn(feature, ilens)

        # one mask per speaker: linear projection followed by the nonlinearity
        masks = []
        for linear in self.linear:
            y = linear(x)
            y = self.nonlinear(y)
            masks.append(y)

        # apply each mask to the original (possibly complex) input
        masked = [input * m for m in masks]

        others = OrderedDict(
            zip(["mask_spk{}".format(i + 1) for i in range(len(masks))], masks)
        )

        return masked, ilens, others

    @property
    def num_spk(self):
        return self._num_spk
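

# Minimal usage sketch: build an RNNSeparator and run a dummy forward pass.
# The batch size, frame count, and the 129-bin feature dimension below are
# illustrative assumptions, not values prescribed by ESPnet.
if __name__ == "__main__":
    separator = RNNSeparator(input_dim=129, rnn_type="blstm", num_spk=2)

    batch, frames, freq = 4, 100, 129
    feature = torch.randn(batch, frames, freq)  # e.g. a magnitude spectrogram
    ilens = torch.full((batch,), frames, dtype=torch.long)

    masked, ilens, others = separator(feature, ilens)
    print(len(masked), masked[0].shape)  # 2 torch.Size([4, 100, 129])
    print(list(others.keys()))  # ['mask_spk1', 'mask_spk2']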
|