Spaces:
Runtime error
Runtime error
from typing import List | |
from typing import Optional | |
from typing import Tuple | |
from typing import Union | |
import numpy | |
import torch | |
import torch.nn as nn | |
from torch_complex.tensor import ComplexTensor | |
from funasr_detach.frontends.utils.dnn_beamformer import DNN_Beamformer | |
from funasr_detach.frontends.utils.dnn_wpe import DNN_WPE | |
class Frontend(nn.Module): | |
def __init__( | |
self, | |
idim: int, | |
# WPE options | |
use_wpe: bool = False, | |
wtype: str = "blstmp", | |
wlayers: int = 3, | |
wunits: int = 300, | |
wprojs: int = 320, | |
wdropout_rate: float = 0.0, | |
taps: int = 5, | |
delay: int = 3, | |
use_dnn_mask_for_wpe: bool = True, | |
# Beamformer options | |
use_beamformer: bool = False, | |
btype: str = "blstmp", | |
blayers: int = 3, | |
bunits: int = 300, | |
bprojs: int = 320, | |
bnmask: int = 2, | |
badim: int = 320, | |
ref_channel: int = -1, | |
bdropout_rate=0.0, | |
): | |
super().__init__() | |
self.use_beamformer = use_beamformer | |
self.use_wpe = use_wpe | |
self.use_dnn_mask_for_wpe = use_dnn_mask_for_wpe | |
# use frontend for all the data, | |
# e.g. in the case of multi-speaker speech separation | |
self.use_frontend_for_all = bnmask > 2 | |
if self.use_wpe: | |
if self.use_dnn_mask_for_wpe: | |
# Use DNN for power estimation | |
# (Not observed significant gains) | |
iterations = 1 | |
else: | |
# Performing as conventional WPE, without DNN Estimator | |
iterations = 2 | |
self.wpe = DNN_WPE( | |
wtype=wtype, | |
widim=idim, | |
wunits=wunits, | |
wprojs=wprojs, | |
wlayers=wlayers, | |
taps=taps, | |
delay=delay, | |
dropout_rate=wdropout_rate, | |
iterations=iterations, | |
use_dnn_mask=use_dnn_mask_for_wpe, | |
) | |
else: | |
self.wpe = None | |
if self.use_beamformer: | |
self.beamformer = DNN_Beamformer( | |
btype=btype, | |
bidim=idim, | |
bunits=bunits, | |
bprojs=bprojs, | |
blayers=blayers, | |
bnmask=bnmask, | |
dropout_rate=bdropout_rate, | |
badim=badim, | |
ref_channel=ref_channel, | |
) | |
else: | |
self.beamformer = None | |
def forward( | |
self, x: ComplexTensor, ilens: Union[torch.LongTensor, numpy.ndarray, List[int]] | |
) -> Tuple[ComplexTensor, torch.LongTensor, Optional[ComplexTensor]]: | |
assert len(x) == len(ilens), (len(x), len(ilens)) | |
# (B, T, F) or (B, T, C, F) | |
if x.dim() not in (3, 4): | |
raise ValueError(f"Input dim must be 3 or 4: {x.dim()}") | |
if not torch.is_tensor(ilens): | |
ilens = torch.from_numpy(numpy.asarray(ilens)).to(x.device) | |
mask = None | |
h = x | |
if h.dim() == 4: | |
if self.training: | |
choices = [(False, False)] if not self.use_frontend_for_all else [] | |
if self.use_wpe: | |
choices.append((True, False)) | |
if self.use_beamformer: | |
choices.append((False, True)) | |
use_wpe, use_beamformer = choices[numpy.random.randint(len(choices))] | |
else: | |
use_wpe = self.use_wpe | |
use_beamformer = self.use_beamformer | |
# 1. WPE | |
if use_wpe: | |
# h: (B, T, C, F) -> h: (B, T, C, F) | |
h, ilens, mask = self.wpe(h, ilens) | |
# 2. Beamformer | |
if use_beamformer: | |
# h: (B, T, C, F) -> h: (B, T, F) | |
h, ilens, mask = self.beamformer(h, ilens) | |
return h, ilens, mask | |
def frontend_for(args, idim): | |
return Frontend( | |
idim=idim, | |
# WPE options | |
use_wpe=args.use_wpe, | |
wtype=args.wtype, | |
wlayers=args.wlayers, | |
wunits=args.wunits, | |
wprojs=args.wprojs, | |
wdropout_rate=args.wdropout_rate, | |
taps=args.wpe_taps, | |
delay=args.wpe_delay, | |
use_dnn_mask_for_wpe=args.use_dnn_mask_for_wpe, | |
# Beamformer options | |
use_beamformer=args.use_beamformer, | |
btype=args.btype, | |
blayers=args.blayers, | |
bunits=args.bunits, | |
bprojs=args.bprojs, | |
bnmask=args.bnmask, | |
badim=args.badim, | |
ref_channel=args.ref_channel, | |
bdropout_rate=args.bdropout_rate, | |
) | |