File size: 6,480 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
from collections import OrderedDict
from typing import List
from typing import Tuple
from typing import Union
import torch
from torch_complex.tensor import ComplexTensor
from espnet.nets.pytorch_backend.conformer.encoder import (
Encoder as ConformerEncoder, # noqa: H301
)
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
from espnet2.enh.separator.abs_separator import AbsSeparator
class ConformerSeparator(AbsSeparator):
    """Mask-based separator built on a Conformer encoder.

    The input feature is passed through a Conformer encoder. One linear layer
    per speaker, followed by a shared nonlinearity, maps the encoder output to
    a mask, and each mask is multiplied element-wise with the input.
    """

    def __init__(
        self,
        input_dim: int,
        num_spk: int = 2,
        adim: int = 384,
        aheads: int = 4,
        layers: int = 6,
        linear_units: int = 1536,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        normalize_before: bool = False,
        concat_after: bool = False,
        dropout_rate: float = 0.1,
        input_layer: str = "linear",
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.1,
        nonlinear: str = "relu",
        conformer_pos_enc_layer_type: str = "rel_pos",
        conformer_self_attn_layer_type: str = "rel_selfattn",
        conformer_activation_type: str = "swish",
        use_macaron_style_in_conformer: bool = True,
        use_cnn_in_conformer: bool = True,
        conformer_enc_kernel_size: int = 7,
        padding_idx: int = -1,
    ):
        """Conformer separator.

        Args:
            input_dim: input feature dimension
            num_spk: number of speakers
            adim (int): Dimension of attention.
            aheads (int): The number of heads of multi head attention.
            linear_units (int): The number of units of position-wise feed forward.
            layers (int): The number of transformer blocks.
            dropout_rate (float): Dropout rate.
            input_layer (Union[str, torch.nn.Module]): Input layer type.
            attention_dropout_rate (float): Dropout rate in attention.
            positional_dropout_rate (float): Dropout rate after adding
                positional encoding.
            normalize_before (bool): Whether to use layer_norm before the first block.
            concat_after (bool): Whether to concat attention layer's input and output.
                if True, additional linear will be applied.
                i.e. x -> x + linear(concat(x, att(x)))
                if False, no additional linear will be applied. i.e. x -> x + att(x)
            conformer_pos_enc_layer_type (str): Encoder positional encoding layer type.
            conformer_self_attn_layer_type (str): Encoder attention layer type.
            conformer_activation_type (str): Encoder activation function type.
            positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
            positionwise_conv_kernel_size (int): Kernel size of
                positionwise conv1d layer.
            use_macaron_style_in_conformer (bool): Whether to use macaron style for
                positionwise layer.
            use_cnn_in_conformer (bool): Whether to use convolution module.
            conformer_enc_kernel_size (int): Kernel size of convolution module.
            padding_idx (int): Padding idx for input_layer=embed.
            nonlinear: the nonlinear function for mask estimation,
                select from 'relu', 'tanh', 'sigmoid'

        Raises:
            ValueError: if ``nonlinear`` is not one of the supported names.
        """
        # Validate the cheap option first so a bad value fails before the
        # (comparatively expensive) encoder is constructed.
        activation_classes = {
            "sigmoid": torch.nn.Sigmoid,
            "relu": torch.nn.ReLU,
            "tanh": torch.nn.Tanh,
        }
        # Membership is tested against a tuple so that an unhashable argument
        # still ends in ValueError rather than TypeError.
        if nonlinear not in tuple(activation_classes):
            raise ValueError("Not supporting nonlinear={}".format(nonlinear))

        super().__init__()

        self._num_spk = num_spk

        self.conformer = ConformerEncoder(
            idim=input_dim,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=linear_units,
            num_blocks=layers,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            attention_dropout_rate=attention_dropout_rate,
            input_layer=input_layer,
            normalize_before=normalize_before,
            concat_after=concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            macaron_style=use_macaron_style_in_conformer,
            pos_enc_layer_type=conformer_pos_enc_layer_type,
            selfattention_layer_type=conformer_self_attn_layer_type,
            activation_type=conformer_activation_type,
            use_cnn_module=use_cnn_in_conformer,
            cnn_module_kernel=conformer_enc_kernel_size,
            padding_idx=padding_idx,
        )

        # One projection head per speaker: encoder dim -> input feature dim.
        self.linear = torch.nn.ModuleList(
            [torch.nn.Linear(adim, input_dim) for _ in range(self.num_spk)]
        )

        # Assigned last so the sub-module registration order
        # (conformer, linear, nonlinear) is unchanged.
        self.nonlinear = activation_classes[nonlinear]()

    def forward(
        self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor
    ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
        """Forward.

        Args:
            input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...]
            ilens (torch.Tensor): (B,), the input lengths, returned unchanged
            others predicted data, e.g. masks: OrderedDict[
                'mask_spk1': torch.Tensor(Batch, Frames, Freq),
                'mask_spk2': torch.Tensor(Batch, Frames, Freq),
                ...
                'mask_spkn': torch.Tensor(Batch, Frames, Freq),
            ]
        """
        # For a complex spectrum the masks are estimated from abs(input);
        # the masks are still applied to the original (complex) input below.
        if isinstance(input, ComplexTensor):
            feature = abs(input)
        else:
            feature = input

        # Non-padding mask with an extra axis at dim 1, on the feature's device.
        pad_mask = make_non_pad_mask(ilens).unsqueeze(1).to(feature.device)

        # NOTE(review): per the ESPnet encoder API the second return value is
        # the mask tensor (#batch, 1, time), not lengths. It used to be bound
        # to `ilens`, so a mask was returned where (B,) lengths are documented.
        # It is discarded here - confirm against the installed ESPnet version.
        x, _ = self.conformer(feature, pad_mask)

        # One mask per speaker head, all sharing the same nonlinearity.
        masks = [self.nonlinear(linear(x)) for linear in self.linear]

        masked = [input * m for m in masks]

        others = OrderedDict(
            ("mask_spk{}".format(idx), mask)
            for idx, mask in enumerate(masks, start=1)
        )

        return masked, ilens, others

    @property
    def num_spk(self):
        """Number of speakers (one mask / projection head per speaker)."""
        return self._num_spk
|