tobiasc's picture
Initial commit
ad16788
raw
history blame
4.61 kB
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import numpy
import torch
import torch.nn as nn
from torch_complex.tensor import ComplexTensor
from espnet.nets.pytorch_backend.frontends.dnn_beamformer import DNN_Beamformer
from espnet.nets.pytorch_backend.frontends.dnn_wpe import DNN_WPE
class Frontend(nn.Module):
    """Multi-channel frontend chaining an optional WPE stage and a beamformer.

    Both stages are optional and are only applied to 4-dimensional
    ``(B, T, C, F)`` inputs; 3-dimensional ``(B, T, F)`` inputs are returned
    untouched.  While the module is in training mode at most one stage is
    applied per call, picked at random among the enabled ones (plus a
    "no frontend" option unless ``bnmask > 2``).  In evaluation mode every
    enabled stage is applied, WPE first and then the beamformer.
    """

    def __init__(
        self,
        idim: int,
        # WPE options
        use_wpe: bool = False,
        wtype: str = "blstmp",
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        wdropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask_for_wpe: bool = True,
        # Beamformer options
        use_beamformer: bool = False,
        btype: str = "blstmp",
        blayers: int = 3,
        bunits: int = 300,
        bprojs: int = 320,
        bnmask: int = 2,
        badim: int = 320,
        ref_channel: int = -1,
        bdropout_rate: float = 0.0,
    ):
        """Create the enabled sub-modules.

        Args:
            idim: Input feature dimension, forwarded to both sub-modules
                (as ``widim`` / ``bidim``).
            use_wpe: Build and use the ``DNN_WPE`` stage.
            wtype, wlayers, wunits, wprojs, wdropout_rate, taps, delay:
                Forwarded unchanged to ``DNN_WPE``.
            use_dnn_mask_for_wpe: Forwarded to ``DNN_WPE`` as
                ``use_dnn_mask``; also selects the WPE iteration count
                (1 with the DNN mask, 2 without).
            use_beamformer: Build and use the ``DNN_Beamformer`` stage.
            btype, blayers, bunits, bprojs, bnmask, badim, ref_channel,
            bdropout_rate: Forwarded unchanged to ``DNN_Beamformer``.
                ``bnmask > 2`` additionally removes the "no frontend"
                option from the random choice made during training.
        """
        super().__init__()

        self.use_beamformer = use_beamformer
        self.use_wpe = use_wpe
        self.use_dnn_mask_for_wpe = use_dnn_mask_for_wpe
        # use frontend for all the data,
        # e.g. in the case of multi-speaker speech separation
        self.use_frontend_for_all = bnmask > 2

        if self.use_wpe:
            if self.use_dnn_mask_for_wpe:
                # Use DNN for power estimation
                # (Not observed significant gains)
                iterations = 1
            else:
                # Performing as conventional WPE, without DNN Estimator
                iterations = 2

            self.wpe = DNN_WPE(
                wtype=wtype,
                widim=idim,
                wunits=wunits,
                wprojs=wprojs,
                wlayers=wlayers,
                taps=taps,
                delay=delay,
                dropout_rate=wdropout_rate,
                iterations=iterations,
                use_dnn_mask=use_dnn_mask_for_wpe,
            )
        else:
            self.wpe = None

        if self.use_beamformer:
            self.beamformer = DNN_Beamformer(
                btype=btype,
                bidim=idim,
                bunits=bunits,
                bprojs=bprojs,
                blayers=blayers,
                bnmask=bnmask,
                dropout_rate=bdropout_rate,
                badim=badim,
                ref_channel=ref_channel,
            )
        else:
            self.beamformer = None

    def forward(
        self,
        x: "ComplexTensor",
        ilens: Union[torch.LongTensor, numpy.ndarray, List[int]],
    ) -> Tuple["ComplexTensor", torch.LongTensor, Optional["ComplexTensor"]]:
        """Apply the selected frontend stages to a batch.

        Args:
            x: Input of shape ``(B, T, F)`` or ``(B, T, C, F)``.
            ilens: One length per batch item.  Anything that is not already
                a tensor is converted to one on ``x``'s device.

        Returns:
            A tuple ``(h, ilens, mask)``.  ``h`` is the processed signal
            (``x`` itself when no stage ran), ``ilens`` the lengths as a
            tensor, and ``mask`` whatever the last executed stage returned
            as its third output, or ``None`` when no stage ran.

        Raises:
            ValueError: If ``x`` is neither 3- nor 4-dimensional.
        """
        assert len(x) == len(ilens), (len(x), len(ilens))
        # (B, T, F) or (B, T, C, F)
        if x.dim() not in (3, 4):
            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
        if not torch.is_tensor(ilens):
            ilens = torch.from_numpy(numpy.asarray(ilens)).to(x.device)

        mask = None
        h = x
        if h.dim() == 4:
            if self.training:
                # Candidate (use_wpe, use_beamformer) pairs; "no frontend"
                # is only offered when the frontend is not mandatory.
                choices = [] if self.use_frontend_for_all else [(False, False)]
                if self.use_wpe:
                    choices.append((True, False))
                if self.use_beamformer:
                    choices.append((False, True))

                if choices:
                    pick = numpy.random.randint(len(choices))
                    use_wpe, use_beamformer = choices[pick]
                else:
                    # The frontend is mandatory (bnmask > 2) but no stage is
                    # enabled: there is nothing to sample from, and
                    # numpy.random.randint(0) would raise ValueError.
                    # Pass the input through unchanged instead.
                    use_wpe = use_beamformer = False
            else:
                use_wpe = self.use_wpe
                use_beamformer = self.use_beamformer

            # 1. WPE
            if use_wpe:
                # h: (B, T, C, F) -> h: (B, T, C, F)
                h, ilens, mask = self.wpe(h, ilens)

            # 2. Beamformer
            if use_beamformer:
                # h: (B, T, C, F) -> h: (B, T, F)
                h, ilens, mask = self.beamformer(h, ilens)

        return h, ilens, mask
def frontend_for(args, idim):
    """Build a ``Frontend`` from an options object.

    Args:
        args: Object exposing the frontend options as attributes
            (``use_wpe``, ``wtype``, ..., ``bdropout_rate``).  Note that the
            WPE tap and delay settings are read from ``wpe_taps`` and
            ``wpe_delay``.
        idim: Input feature dimension passed to ``Frontend``.

    Returns:
        The constructed ``Frontend`` instance.
    """
    # (Frontend keyword, attribute name on ``args``), in lookup order.
    option_sources = (
        # WPE options
        ("use_wpe", "use_wpe"),
        ("wtype", "wtype"),
        ("wlayers", "wlayers"),
        ("wunits", "wunits"),
        ("wprojs", "wprojs"),
        ("wdropout_rate", "wdropout_rate"),
        ("taps", "wpe_taps"),
        ("delay", "wpe_delay"),
        ("use_dnn_mask_for_wpe", "use_dnn_mask_for_wpe"),
        # Beamformer options
        ("use_beamformer", "use_beamformer"),
        ("btype", "btype"),
        ("blayers", "blayers"),
        ("bunits", "bunits"),
        ("bprojs", "bprojs"),
        ("bnmask", "bnmask"),
        ("badim", "badim"),
        ("ref_channel", "ref_channel"),
        ("bdropout_rate", "bdropout_rate"),
    )
    options = {keyword: getattr(args, attr) for keyword, attr in option_sources}
    return Frontend(idim=idim, **options)