conex / espnet2 /enh /separator /rnn_separator.py
tobiasc's picture
Initial commit
ad16788
raw
history blame
3.24 kB
from collections import OrderedDict
from typing import List
from typing import Tuple
from typing import Union
import torch
from torch_complex.tensor import ComplexTensor
from espnet.nets.pytorch_backend.rnn.encoders import RNN
from espnet2.enh.separator.abs_separator import AbsSeparator
class RNNSeparator(AbsSeparator):
    """Mask-based separator built on a stacked RNN.

    The input feature is passed through ``RNN``; one ``torch.nn.Linear`` head
    per speaker maps the RNN output (size ``unit``) back to ``input_dim``, and
    the chosen nonlinearity turns each head's output into a mask that is
    multiplied element-wise with the original input.
    """

    def __init__(
        self,
        input_dim: int,
        rnn_type: str = "blstm",
        num_spk: int = 2,
        nonlinear: str = "sigmoid",
        layer: int = 3,
        unit: int = 512,
        dropout: float = 0.0,
    ):
        """RNN Separator

        Args:
            input_dim: input feature dimension
            rnn_type: string, select from 'blstm', 'lstm' etc.
                (forwarded to ``RNN`` as ``typ``)
            num_spk: number of speakers (one linear mask head is created
                per speaker)
            nonlinear: the nonlinear function for mask estimation,
                select from 'relu', 'tanh', 'sigmoid'
            layer: int, number of stacked RNN layers. Default is 3.
            unit: int, dimension of the hidden state.
            dropout: float, dropout ratio. Default is 0.

        Raises:
            ValueError: if ``nonlinear`` is not one of the supported names.
        """
        super().__init__()

        # Fail fast: reject an unsupported activation before any (potentially
        # large) sub-module is allocated.
        if nonlinear not in ("sigmoid", "relu", "tanh"):
            raise ValueError("Not supporting nonlinear={}".format(nonlinear))

        self._num_spk = num_spk

        self.rnn = RNN(
            idim=input_dim,
            elayers=layer,
            cdim=unit,
            hdim=unit,
            dropout=dropout,
            typ=rnn_type,
        )

        # One projection per speaker, mapping RNN output back to feature size.
        self.linear = torch.nn.ModuleList(
            [torch.nn.Linear(unit, input_dim) for _ in range(self.num_spk)]
        )

        self.nonlinear = {
            "sigmoid": torch.nn.Sigmoid(),
            "relu": torch.nn.ReLU(),
            "tanh": torch.nn.Tanh(),
        }[nonlinear]

    @staticmethod
    def _is_complex(tensor) -> bool:
        """Return True if ``tensor`` is a ComplexTensor or a native complex tensor."""
        if isinstance(tensor, ComplexTensor):
            return True
        # ``torch.is_complex`` is absent from old PyTorch releases; on those
        # versions every ``torch.Tensor`` is treated as real, as before.
        native_check = getattr(torch, "is_complex", None)
        return (
            native_check is not None
            and torch.is_tensor(tensor)
            and bool(native_check(tensor))
        )

    def forward(
        self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor
    ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
        """Forward.

        Args:
            input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...]
                the input multiplied by each speaker's mask
            ilens (torch.Tensor): (B,) lengths as returned by the RNN
            others predicted data, e.g. masks: OrderedDict[
                'mask_spk1': torch.Tensor(Batch, Frames, Freq),
                'mask_spk2': torch.Tensor(Batch, Frames, Freq),
                ...
                'mask_spkn': torch.Tensor(Batch, Frames, Freq),
            ]
        """
        # The RNN operates on real values: for a complex spectrum (either a
        # ComplexTensor or a native complex torch.Tensor) use its magnitude.
        feature = abs(input) if self._is_complex(input) else input

        x, ilens, _ = self.rnn(feature, ilens)

        # One mask per speaker head.
        masks = [self.nonlinear(linear(x)) for linear in self.linear]

        # Masks are applied to the original (possibly complex) input.
        masked = [input * m for m in masks]

        others = OrderedDict(
            ("mask_spk{}".format(idx), m) for idx, m in enumerate(masks, start=1)
        )

        return masked, ilens, others

    @property
    def num_spk(self):
        """Number of speakers (and of mask heads) configured at construction."""
        return self._num_spk