from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import numpy
import torch
import torch.nn as nn
from torch_complex.tensor import ComplexTensor
from espnet.nets.pytorch_backend.frontends.dnn_beamformer import DNN_Beamformer
from espnet.nets.pytorch_backend.frontends.dnn_wpe import DNN_WPE
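
# This frontend optionally applies DNN-supported WPE dereverberation and/or
# DNN mask-based beamforming to STFT features before the recognizer.
# Single-channel input (B, T, F) is passed through unchanged; multi-channel
# input (B, T, C, F) is dereverberated and/or reduced to a single channel
# by the beamformer.
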
class Frontend(nn.Module):
    def __init__(
        self,
        idim: int,
        # WPE options
        use_wpe: bool = False,
        wtype: str = "blstmp",
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        wdropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask_for_wpe: bool = True,
        # Beamformer options
        use_beamformer: bool = False,
        btype: str = "blstmp",
        blayers: int = 3,
        bunits: int = 300,
        bprojs: int = 320,
        bnmask: int = 2,
        badim: int = 320,
        ref_channel: int = -1,
        bdropout_rate: float = 0.0,
    ):
        super().__init__()
        self.use_beamformer = use_beamformer
        self.use_wpe = use_wpe
        self.use_dnn_mask_for_wpe = use_dnn_mask_for_wpe
        # Apply the frontend to all data, e.g. in the case of
        # multi-speaker speech separation (more than two masks)
        self.use_frontend_for_all = bnmask > 2

        if self.use_wpe:
            if self.use_dnn_mask_for_wpe:
                # Use a DNN for power estimation in a single iteration
                # (no significant gains were observed)
                iterations = 1
            else:
                # Perform conventional iterative WPE, without a DNN estimator
                iterations = 2

            self.wpe = DNN_WPE(
                wtype=wtype,
                widim=idim,
                wunits=wunits,
                wprojs=wprojs,
                wlayers=wlayers,
                taps=taps,
                delay=delay,
                dropout_rate=wdropout_rate,
                iterations=iterations,
                use_dnn_mask=use_dnn_mask_for_wpe,
            )
        else:
            self.wpe = None

        if self.use_beamformer:
            self.beamformer = DNN_Beamformer(
                btype=btype,
                bidim=idim,
                bunits=bunits,
                bprojs=bprojs,
                blayers=blayers,
                bnmask=bnmask,
                dropout_rate=bdropout_rate,
                badim=badim,
                ref_channel=ref_channel,
            )
        else:
            self.beamformer = None

    def forward(
        self, x: ComplexTensor, ilens: Union[torch.LongTensor, numpy.ndarray, List[int]]
    ) -> Tuple[ComplexTensor, torch.LongTensor, Optional[ComplexTensor]]:
        assert len(x) == len(ilens), (len(x), len(ilens))
        # (B, T, F) or (B, T, C, F)
        if x.dim() not in (3, 4):
            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
        if not torch.is_tensor(ilens):
            ilens = torch.from_numpy(numpy.asarray(ilens)).to(x.device)

        mask = None
        h = x
        if h.dim() == 4:
            if self.training:
                # Randomly enable/disable the enhancement steps per batch so
                # the model also sees unprocessed data, unless the frontend
                # must be applied to everything (bnmask > 2)
                choices = [(False, False)] if not self.use_frontend_for_all else []
                if self.use_wpe:
                    choices.append((True, False))
                if self.use_beamformer:
                    choices.append((False, True))

                use_wpe, use_beamformer = choices[numpy.random.randint(len(choices))]
            else:
                use_wpe = self.use_wpe
                use_beamformer = self.use_beamformer

            # 1. WPE dereverberation
            if use_wpe:
                # h: (B, T, C, F) -> h: (B, T, C, F)
                h, ilens, mask = self.wpe(h, ilens)

            # 2. Beamformer: reduce the channel dimension
            if use_beamformer:
                # h: (B, T, C, F) -> h: (B, T, F)
                h, ilens, mask = self.beamformer(h, ilens)

        return h, ilens, mask
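

# frontend_for() builds a Frontend from the training script's command-line
# arguments. It assumes ``args`` is an argparse.Namespace exposing the WPE
# options (use_wpe, wtype, wlayers, wunits, wprojs, wdropout_rate, wpe_taps,
# wpe_delay, use_dnn_mask_for_wpe) and the beamformer options (use_beamformer,
# btype, blayers, bunits, bprojs, bnmask, badim, ref_channel, bdropout_rate).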
def frontend_for(args, idim: int) -> Frontend:
    return Frontend(
        idim=idim,
        # WPE options
        use_wpe=args.use_wpe,
        wtype=args.wtype,
        wlayers=args.wlayers,
        wunits=args.wunits,
        wprojs=args.wprojs,
        wdropout_rate=args.wdropout_rate,
        taps=args.wpe_taps,
        delay=args.wpe_delay,
        use_dnn_mask_for_wpe=args.use_dnn_mask_for_wpe,
        # Beamformer options
        use_beamformer=args.use_beamformer,
        btype=args.btype,
        blayers=args.blayers,
        bunits=args.bunits,
        bprojs=args.bprojs,
        bnmask=args.bnmask,
        badim=args.badim,
        ref_channel=args.ref_channel,
        bdropout_rate=args.bdropout_rate,
    )
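

if __name__ == "__main__":
    # Minimal smoke test (a sketch added for illustration, not part of the
    # original module): run the pass-through path on random multi-channel
    # features. Shapes are (B, T, C, F) as documented in forward().
    B, T, C, F = 2, 100, 4, 257
    x = ComplexTensor(torch.randn(B, T, C, F), torch.randn(B, T, C, F))
    ilens = torch.LongTensor([100, 80])

    # With WPE and the beamformer both disabled (the defaults), forward()
    # returns the input unchanged and the mask stays None.
    frontend = Frontend(idim=F)
    h, olens, mask = frontend(x, ilens)
    print(h.shape, olens, mask)  # torch.Size([2, 100, 4, 257]) tensor([100, 80]) None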