"""
This script is used to construct End-to-End models of multi-speaker ASR.

Copyright 2017 Johns Hopkins University (Shinji Watanabe)
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""

import argparse
from itertools import groupby
import logging
import math
import os
import sys

import editdistance
import numpy as np
import six
import torch

from espnet.nets.asr_interface import ASRInterface
from espnet.nets.e2e_asr_common import get_vgg2l_odim
from espnet.nets.e2e_asr_common import label_smoothing_dist
from espnet.nets.pytorch_backend.ctc import ctc_for
from espnet.nets.pytorch_backend.e2e_asr import E2E as E2EASR
from espnet.nets.pytorch_backend.e2e_asr import Reporter
from espnet.nets.pytorch_backend.frontends.feature_transform import (
    feature_transform_for,
)
from espnet.nets.pytorch_backend.frontends.frontend import frontend_for
from espnet.nets.pytorch_backend.initialization import lecun_normal_init_parameters
from espnet.nets.pytorch_backend.initialization import set_forget_bias_to_one
from espnet.nets.pytorch_backend.nets_utils import get_subsample
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
from espnet.nets.pytorch_backend.nets_utils import pad_list
from espnet.nets.pytorch_backend.nets_utils import to_device
from espnet.nets.pytorch_backend.nets_utils import to_torch_tensor
from espnet.nets.pytorch_backend.rnn.attentions import att_for
from espnet.nets.pytorch_backend.rnn.decoders import decoder_for
from espnet.nets.pytorch_backend.rnn.encoders import encoder_for as encoder_for_single
from espnet.nets.pytorch_backend.rnn.encoders import RNNP
from espnet.nets.pytorch_backend.rnn.encoders import VGG2L


CTC_LOSS_THRESHOLD = 10000


class PIT(object):
    """Permutation Invariant Training (PIT) module.

    :param int num_spkrs: number of speakers for the PIT process (2 or 3)
    """

    def __init__(self, num_spkrs):
        """Initialize PIT module."""
        self.num_spkrs = num_spkrs

        # enumerate all permutations of [0, ..., num_spkrs - 1]
        self.perm_choices = []
        initial_seq = np.linspace(0, num_spkrs - 1, num_spkrs, dtype=np.int64)
        self.permutationDFS(initial_seq, 0)

        # build, for each permutation, the indices into the flattened
        # (num_spkrs ** 2,) loss vector that the permutation selects
        self.loss_perm_idx = np.linspace(
            0, num_spkrs * (num_spkrs - 1), num_spkrs, dtype=np.int64
        ).reshape(1, num_spkrs)
        self.loss_perm_idx = (self.loss_perm_idx + np.array(self.perm_choices)).tolist()
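
    # Note on the index construction in __init__ (an illustrative walk-through,
    # assuming num_spkrs == 2). The flattened loss vector is ordered
    # [h1r1, h1r2, h2r1, h2r2], so hypothesis i scored against reference j
    # sits at index i * num_spkrs + j:
    #   initial_seq            -> [0, 1]
    #   perm_choices           -> [[0, 1], [1, 0]]
    #   row offsets (linspace) -> [[0, 2]]
    #   loss_perm_idx          -> [[0, 3], [1, 2]]
    # i.e. the pairing (h1->r1, h2->r2) sums loss[0] + loss[3], while the
    # swapped pairing (h1->r2, h2->r1) sums loss[1] + loss[2].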

    def min_pit_sample(self, loss):
        """Compute the PIT loss for one sample.

        :param 1-D torch.Tensor loss: losses for one sample, ordered as
            [h1r1, h1r2, h2r1, h2r2] or
            [h1r1, h1r2, h1r3, h2r1, h2r2, h2r3, h3r1, h3r2, h3r3]
        :return: minimum loss over all permutations
        :rtype: torch.Tensor (1)
        :return: the best permutation
        :rtype: List (len=num_spkrs)
        """
        # average the per-speaker losses for every candidate permutation
        score_perms = (
            torch.stack(
                [torch.sum(loss[loss_perm_idx]) for loss_perm_idx in self.loss_perm_idx]
            )
            / self.num_spkrs
        )
        perm_loss, min_idx = torch.min(score_perms, 0)
        permutation = self.perm_choices[min_idx]
        return perm_loss, permutation
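
    # A minimal numeric sketch of min_pit_sample (illustrative values only;
    # 2 speakers):
    #   pit = PIT(2)
    #   loss = torch.tensor([1.0, 4.0, 3.0, 2.0])  # [h1r1, h1r2, h2r1, h2r2]
    #   perm_loss, perm = pit.min_pit_sample(loss)
    #   # identity pairing: (1.0 + 2.0) / 2 = 1.5
    #   # swapped pairing:  (4.0 + 3.0) / 2 = 3.5
    #   # -> perm_loss == 1.5, perm == [0, 1]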

    def pit_process(self, losses):
        """Compute the PIT loss for a batch.

        :param torch.Tensor losses: losses (B, 1|4|9)
        :return: mean over the batch of the minimum loss under the best permutation
        :rtype: torch.Tensor (1)
        :return: the best permutation per sample
        :rtype: torch.LongTensor (B, 1|2|3)
        """
        bs = losses.size(0)
        ret = [self.min_pit_sample(losses[i]) for i in range(bs)]

        loss_perm = torch.stack([r[0] for r in ret], dim=0).to(losses.device)
        permutation = torch.tensor([r[1] for r in ret]).long().to(losses.device)
        return torch.mean(loss_perm), permutation
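
    # Batch-level sketch of pit_process (illustrative values; 2 speakers, B == 2):
    #   pit = PIT(2)
    #   losses = torch.tensor([[1.0, 4.0, 3.0, 2.0],
    #                          [5.0, 1.0, 2.0, 6.0]])  # (B, num_spkrs ** 2)
    #   loss, perm = pit.pit_process(losses)
    #   # sample 0 keeps the identity pairing -> 1.5, perm [0, 1]
    #   # sample 1 prefers the swap           -> 1.5, perm [1, 0]
    #   # loss == mean(1.5, 1.5) == 1.5; perm has shape (B, num_spkrs)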

    def permutationDFS(self, source, start):
        """Get permutations with DFS.

        The final result is all permutations of the 'source' sequence.
        e.g. [[1, 2], [2, 1]] or
        [[1, 2, 3], [1, 3, 2], [2, 1, 3], [2, 3, 1], [3, 2, 1], [3, 1, 2]]

        :param np.ndarray source: (num_spkrs,), e.g. [1, 2, ..., N]
        :param int start: the start point to permute
        """
        if start == len(source) - 1:  # reached a leaf: record one permutation
            self.perm_choices.append(source.tolist())
        for i in range(start, len(source)):
            # swap the elements at positions start and i
            source[start], source[i] = source[i], source[start]
            self.permutationDFS(source, start + 1)
            # swap back to restore the original order
            source[start], source[i] = source[i], source[start]


class E2E(ASRInterface, torch.nn.Module):
    """E2E module.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments."""
        E2EASR.encoder_add_arguments(parser)
        E2E.encoder_mix_add_arguments(parser)
        E2EASR.attention_add_arguments(parser)
        E2EASR.decoder_add_arguments(parser)
        return parser

    @staticmethod
    def encoder_mix_add_arguments(parser):
        """Add arguments for the multi-speaker encoder."""
        group = parser.add_argument_group("E2E encoder setting for multi-speaker")
        group.add_argument(
            "--spa",
            action="store_true",
            help="Enable speaker parallel attention "
            "for the multi-speaker speech recognition task.",
        )
        group.add_argument(
            "--elayers-sd",
            default=4,
            type=int,
            help="Number of speaker-differentiating encoder layers "
            "for the multi-speaker speech recognition task.",
        )
        return parser
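
    # Example invocation (illustrative only; the two flag spellings follow the
    # argparse definitions above, while the surrounding training command and
    # the --num-spkrs flag are assumptions about the recipe setup):
    #   asr_train.py ... --num-spkrs 2 --spa --elayers-sd 2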

    def get_total_subsampling_factor(self):
        """Get total subsampling factor."""
        return self.enc.conv_subsampling_factor * int(np.prod(self.subsample))

    def __init__(self, idim, odim, args):
        """Initialize multi-speaker E2E module."""
        super(E2E, self).__init__()
        torch.nn.Module.__init__(self)
        self.mtlalpha = args.mtlalpha
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be in [0.0, 1.0]"
        self.etype = args.etype
        self.verbose = args.verbose
        # fill missing arguments for compatibility
        args.char_list = getattr(args, "char_list", None)
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.space = args.sym_space
        self.blank = args.sym_blank
        self.reporter = Reporter()
        self.num_spkrs = args.num_spkrs
        self.spa = args.spa
        self.pit = PIT(self.num_spkrs)

        # below means the last number becomes eos/sos ID
        # note that sos/eos IDs are identical
        self.sos = odim - 1
        self.eos = odim - 1

        # subsample info
        self.subsample = get_subsample(args, mode="asr", arch="rnn_mix")

        # label smoothing info
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(
                odim, args.lsm_type, transcript=args.train_json
            )
        else:
            labeldist = None

        if getattr(args, "use_frontend", False):  # use getattr to keep compatibility
            self.frontend = frontend_for(args, idim)
            self.feature_transform = feature_transform_for(args, (idim - 1) * 2)
            idim = args.n_mels
        else:
            self.frontend = None

        # encoder
        self.enc = encoder_for(args, idim, self.subsample)
        # ctc
        self.ctc = ctc_for(args, odim, reduce=False)
        # attention (one per speaker if speaker parallel attention is enabled)
        num_att = self.num_spkrs if args.spa else 1
        self.att = att_for(args, num_att)
        # decoder
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist)

        # weight initialization
        self.init_like_chainer()

        # options for beam search
        if "report_cer" in vars(args) and (args.report_cer or args.report_wer):
            recog_args = {
                "beam_size": args.beam_size,
                "penalty": args.penalty,
                "ctc_weight": args.ctc_weight,
                "maxlenratio": args.maxlenratio,
                "minlenratio": args.minlenratio,
                "lm_weight": args.lm_weight,
                "rnnlm": args.rnnlm,
                "nbest": args.nbest,
                "space": args.sym_space,
                "blank": args.sym_blank,
            }

            self.recog_args = argparse.Namespace(**recog_args)
            self.report_cer = args.report_cer
            self.report_wer = args.report_wer
        else:
            self.report_cer = False
            self.report_wer = False
        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None

    def init_like_chainer(self):
        """Initialize weight like chainer.

        chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0
        pytorch basically uses W, b ~ Uniform(-fan_in ** -0.5, fan_in ** -0.5)

        however, there are two exceptions as far as I know.
        - EmbedID.W ~ Normal(0, 1)
        - LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM)
        """
        lecun_normal_init_parameters(self)

        # exceptions:
        # embed weight ~ Normal(0, 1)
        self.dec.embed.weight.data.normal_(0, 1)

        # forget-gate bias of the decoder LSTMs = 1.0
        for i in six.moves.range(len(self.dec.decoder)):
            set_forget_bias_to_one(self.dec.decoder[i].bias_ih)

    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad:
            batch of padded character id sequence tensor (B, num_spkrs, Lmax)
        :return: loss value
        :rtype: torch.Tensor
        """
        # 0. Frontend
        if self.frontend is not None:
            hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
            if isinstance(hs_pad, list):
                hlens_n = [None] * self.num_spkrs
                for i in range(self.num_spkrs):
                    hs_pad[i], hlens_n[i] = self.feature_transform(hs_pad[i], hlens)
                hlens = hlens_n
            else:
                hs_pad, hlens = self.feature_transform(hs_pad, hlens)
        else:
            hs_pad, hlens = xs_pad, ilens

        # 1. Encoder
        if not isinstance(hs_pad, list):  # single-channel input
            hs_pad, hlens, _ = self.enc(hs_pad, hlens)
        else:  # multi-channel input, already separated into streams
            for i in range(self.num_spkrs):
                hs_pad[i], hlens[i], _ = self.enc(hs_pad[i], hlens[i])

        # 2. CTC loss
        if self.mtlalpha == 0:
            loss_ctc, min_perm = None, None
        else:
            if not isinstance(hs_pad, list):  # single-speaker input
                loss_ctc = torch.mean(self.ctc(hs_pad, hlens, ys_pad))
            else:  # multi-speaker input
                ys_pad = ys_pad.transpose(0, 1)  # (num_spkrs, B, Lmax)
                # CTC loss for every (hypothesis, reference) pair
                loss_ctc_perm = torch.stack(
                    [
                        self.ctc(
                            hs_pad[i // self.num_spkrs],
                            hlens[i // self.num_spkrs],
                            ys_pad[i % self.num_spkrs],
                        )
                        for i in range(self.num_spkrs ** 2)
                    ],
                    dim=1,
                )  # (B, num_spkrs ** 2)
                loss_ctc, min_perm = self.pit.pit_process(loss_ctc_perm)
                logging.info("ctc loss:" + str(float(loss_ctc)))

        # 3. attention loss
        if self.mtlalpha == 1:
            loss_att = None
            acc = None
        else:
            if not isinstance(hs_pad, list):  # single-speaker input
                loss_att, acc, _ = self.dec(hs_pad, hlens, ys_pad)
            else:
                # reorder the references to match the best CTC permutation
                for i in range(ys_pad.size(1)):  # B
                    ys_pad[:, i] = ys_pad[min_perm[i], i]
                rslt = [
                    self.dec(hs_pad[i], hlens[i], ys_pad[i], strm_idx=i)
                    for i in range(self.num_spkrs)
                ]
                loss_att = sum([r[0] for r in rslt]) / float(len(rslt))
                acc = sum([r[1] for r in rslt]) / float(len(rslt))
        self.acc = acc

        # 4. compute cer without beam search
        if self.mtlalpha == 0 or self.char_list is None:
            cer_ctc = None
        else:
            cers = []
            for ns in range(self.num_spkrs):
                y_hats = self.ctc.argmax(hs_pad[ns]).data
                for i, y in enumerate(y_hats):
                    y_hat = [x[0] for x in groupby(y)]
                    y_true = ys_pad[ns][i]

                    seq_hat = [
                        self.char_list[int(idx)] for idx in y_hat if int(idx) != -1
                    ]
                    seq_true = [
                        self.char_list[int(idx)] for idx in y_true if int(idx) != -1
                    ]
                    seq_hat_text = "".join(seq_hat).replace(self.space, " ")
                    seq_hat_text = seq_hat_text.replace(self.blank, "")
                    seq_true_text = "".join(seq_true).replace(self.space, " ")

                    hyp_chars = seq_hat_text.replace(" ", "")
                    ref_chars = seq_true_text.replace(" ", "")
                    if len(ref_chars) > 0:
                        cers.append(
                            editdistance.eval(hyp_chars, ref_chars) / len(ref_chars)
                        )

            cer_ctc = sum(cers) / len(cers) if cers else None

        # 5. compute cer/wer
        if (
            self.training
            or not (self.report_cer or self.report_wer)
            or not isinstance(hs_pad, list)
        ):
            cer, wer = 0.0, 0.0
        else:
            if self.recog_args.ctc_weight > 0.0:
                lpz = [
                    self.ctc.log_softmax(hs_pad[i]).data for i in range(self.num_spkrs)
                ]
            else:
                lpz = None

            word_eds, char_eds, word_ref_lens, char_ref_lens = [], [], [], []
            nbest_hyps = [
                self.dec.recognize_beam_batch(
                    hs_pad[i],
                    torch.tensor(hlens[i]),
                    lpz[i] if lpz is not None else None,
                    self.recog_args,
                    self.char_list,
                    self.rnnlm,
                    strm_idx=i,
                )
                for i in range(self.num_spkrs)
            ]

            # remove <sos> and <eos>
            y_hats = [
                [nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps[i]]
                for i in range(self.num_spkrs)
            ]
            for i in range(len(y_hats[0])):
                hyp_words = []
                hyp_chars = []
                ref_words = []
                ref_chars = []
                for ns in range(self.num_spkrs):
                    y_hat = y_hats[ns][i]
                    y_true = ys_pad[ns][i]

                    seq_hat = [
                        self.char_list[int(idx)] for idx in y_hat if int(idx) != -1
                    ]
                    seq_true = [
                        self.char_list[int(idx)] for idx in y_true if int(idx) != -1
                    ]
                    seq_hat_text = "".join(seq_hat).replace(self.recog_args.space, " ")
                    seq_hat_text = seq_hat_text.replace(self.recog_args.blank, "")
                    seq_true_text = "".join(seq_true).replace(
                        self.recog_args.space, " "
                    )

                    hyp_words.append(seq_hat_text.split())
                    ref_words.append(seq_true_text.split())
                    hyp_chars.append(seq_hat_text.replace(" ", ""))
                    ref_chars.append(seq_true_text.replace(" ", ""))

                # edit distances for every (hypothesis, reference) pair
                tmp_word_ed = [
                    editdistance.eval(
                        hyp_words[ns // self.num_spkrs], ref_words[ns % self.num_spkrs]
                    )
                    for ns in range(self.num_spkrs ** 2)
                ]
                tmp_char_ed = [
                    editdistance.eval(
                        hyp_chars[ns // self.num_spkrs], ref_chars[ns % self.num_spkrs]
                    )
                    for ns in range(self.num_spkrs ** 2)
                ]

                # keep the permutation that minimizes the edit distance
                word_eds.append(self.pit.min_pit_sample(torch.tensor(tmp_word_ed))[0])
                word_ref_lens.append(len(sum(ref_words, [])))
                char_eds.append(self.pit.min_pit_sample(torch.tensor(tmp_char_ed))[0])
                char_ref_lens.append(len("".join(ref_chars)))

            wer = (
                0.0
                if not self.report_wer
                else float(sum(word_eds)) / sum(word_ref_lens)
            )
            cer = (
                0.0
                if not self.report_cer
                else float(sum(char_eds)) / sum(char_ref_lens)
            )

        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(
                loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data
            )
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss
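
    # Hedged usage sketch for forward() (shapes are illustrative only; the
    # `model` construction from training args is omitted and the feature
    # dimension 83 is an arbitrary example):
    #   xs_pad = torch.randn(4, 500, 83)            # (B, Tmax, idim)
    #   ilens = torch.tensor([500, 420, 380, 300])  # (B,)
    #   ys_pad = torch.full((4, 2, 30), -1).long()  # (B, num_spkrs, Lmax)
    #   loss = model(xs_pad, ilens, ys_pad)         # scalar training loss
    #   loss.backward()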

    def recognize(self, x, recog_args, char_list, rnnlm=None):
        """E2E beam search.

        :param ndarray x: input acoustic feature (T, D)
        :param Namespace recog_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        prev = self.training
        self.eval()
        ilens = [x.shape[0]]

        # subsample frames
        x = x[:: self.subsample[0], :]
        h = to_device(self, to_torch_tensor(x).float())
        # make a utterance list of length 1 to reuse the batch interface
        hs = h.contiguous().unsqueeze(0)

        # 0. Frontend
        if self.frontend is not None:
            hs, hlens, mask = self.frontend(hs, ilens)
            hlens_n = [None] * self.num_spkrs
            for i in range(self.num_spkrs):
                hs[i], hlens_n[i] = self.feature_transform(hs[i], hlens)
            hlens = hlens_n
        else:
            hs, hlens = hs, ilens

        # 1. Encoder
        if not isinstance(hs, list):  # single-channel input
            hs, hlens, _ = self.enc(hs, hlens)
        else:  # multi-channel input, already separated into streams
            for i in range(self.num_spkrs):
                hs[i], hlens[i], _ = self.enc(hs[i], hlens[i])

        # calculate log P(z_t|X) for CTC scores
        if recog_args.ctc_weight > 0.0:
            lpz = [self.ctc.log_softmax(i)[0] for i in hs]
        else:
            lpz = None

        # 2. decoder: decode the first utterance of each stream
        y = [
            self.dec.recognize_beam(
                hs[i][0],
                lpz[i] if lpz is not None else None,
                recog_args,
                char_list,
                rnnlm,
                strm_idx=i,
            )
            for i in range(self.num_spkrs)
        ]

        if prev:
            self.train()
        return y
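
    # Hedged usage sketch for recognize() (illustrative; `model`, `args`, and
    # `char_list` come from a trained setup, and the feature shape is made up):
    #   feat = np.random.randn(500, 83).astype(np.float32)  # (T, D)
    #   nbest = model.recognize(feat, args, char_list)
    #   # nbest is a list of length num_spkrs, one n-best list per speaker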

    def recognize_batch(self, xs, recog_args, char_list, rnnlm=None):
        """E2E batch beam search.

        :param list xs: list of input acoustic features, each of shape (T, D)
        :param Namespace recog_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        prev = self.training
        self.eval()
        ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)

        # subsample frames
        xs = [xx[:: self.subsample[0], :] for xx in xs]
        xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
        xs_pad = pad_list(xs, 0.0)

        # 0. Frontend
        if self.frontend is not None:
            hs_pad, hlens, mask = self.frontend(xs_pad, ilens)
            hlens_n = [None] * self.num_spkrs
            for i in range(self.num_spkrs):
                hs_pad[i], hlens_n[i] = self.feature_transform(hs_pad[i], hlens)
            hlens = hlens_n
        else:
            hs_pad, hlens = xs_pad, ilens

        # 1. Encoder
        if not isinstance(hs_pad, list):  # single-channel input
            hs_pad, hlens, _ = self.enc(hs_pad, hlens)
        else:  # multi-channel input, already separated into streams
            for i in range(self.num_spkrs):
                hs_pad[i], hlens[i], _ = self.enc(hs_pad[i], hlens[i])

        # calculate log P(z_t|X) for CTC scores
        if recog_args.ctc_weight > 0.0:
            lpz = [self.ctc.log_softmax(hs_pad[i]) for i in range(self.num_spkrs)]
            normalize_score = False
        else:
            lpz = None
            normalize_score = True

        # 2. decoder
        y = [
            self.dec.recognize_beam_batch(
                hs_pad[i],
                hlens[i],
                lpz[i] if lpz is not None else None,
                recog_args,
                char_list,
                rnnlm,
                normalize_score=normalize_score,
                strm_idx=i,
            )
            for i in range(self.num_spkrs)
        ]

        if prev:
            self.train()
        return y

    def enhance(self, xs):
        """Forward only the frontend stage.

        :param ndarray xs: input acoustic feature (T, C, F)
        """
        if self.frontend is None:
            raise RuntimeError("Frontend doesn't exist")
        prev = self.training
        self.eval()
        ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)

        # subsample frames
        xs = [xx[:: self.subsample[0], :] for xx in xs]
        xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
        xs_pad = pad_list(xs, 0.0)
        enhanced, hlensm, mask = self.frontend(xs_pad, ilens)
        if prev:
            self.train()

        if isinstance(enhanced, (tuple, list)):
            enhanced = list(enhanced)
            mask = list(mask)
            for idx in range(len(enhanced)):  # one entry per speaker
                enhanced[idx] = enhanced[idx].cpu().numpy()
                mask[idx] = mask[idx].cpu().numpy()
            return enhanced, mask, ilens
        return enhanced.cpu().numpy(), mask.cpu().numpy(), ilens

    def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
        """E2E attention calculation.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad:
            batch of padded character id sequence tensor (B, num_spkrs, Lmax)
        :return: attention weights with the following shape,
            1) multi-head case => attention weights (B, H, Lmax, Tmax),
            2) other case => attention weights (B, Lmax, Tmax).
        :rtype: float ndarray
        """
        with torch.no_grad():
            # 0. Frontend
            if self.frontend is not None:
                hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
                hlens_n = [None] * self.num_spkrs
                for i in range(self.num_spkrs):
                    hs_pad[i], hlens_n[i] = self.feature_transform(hs_pad[i], hlens)
                hlens = hlens_n
            else:
                hs_pad, hlens = xs_pad, ilens

            # 1. Encoder
            if not isinstance(hs_pad, list):  # single-channel input
                hs_pad, hlens, _ = self.enc(hs_pad, hlens)
            else:  # multi-channel input, already separated into streams
                for i in range(self.num_spkrs):
                    hs_pad[i], hlens[i], _ = self.enc(hs_pad[i], hlens[i])

            # permutation: align references with hypotheses via the PIT-CTC loss
            ys_pad = ys_pad.transpose(0, 1)  # (num_spkrs, B, Lmax)
            if self.num_spkrs <= 3:
                loss_ctc = torch.stack(
                    [
                        self.ctc(
                            hs_pad[i // self.num_spkrs],
                            hlens[i // self.num_spkrs],
                            ys_pad[i % self.num_spkrs],
                        )
                        for i in range(self.num_spkrs ** 2)
                    ],
                    dim=1,
                )  # (B, num_spkrs ** 2)
                loss_ctc, min_perm = self.pit.pit_process(loss_ctc)
                for i in range(ys_pad.size(1)):  # B
                    ys_pad[:, i] = ys_pad[min_perm[i], i]

            # 2. Decoder
            att_ws = [
                self.dec.calculate_all_attentions(
                    hs_pad[i], hlens[i], ys_pad[i], strm_idx=i
                )
                for i in range(self.num_spkrs)
            ]

        return att_ws


class EncoderMix(torch.nn.Module):
    """Encoder module for the case of multi-speaker mixture speech.

    :param str etype: type of encoder network
    :param int idim: number of dimensions of encoder network
    :param int elayers_sd:
        number of layers of the speaker-differentiating part of the encoder network
    :param int elayers_rec:
        number of layers of the shared recognition part of the encoder network
    :param int eunits: number of lstm units of encoder network
    :param int eprojs: number of projection units of encoder network
    :param np.ndarray subsample: list of subsampling numbers
    :param float dropout: dropout rate
    :param int in_channel: number of input channels
    :param int num_spkrs: number of speakers
    """

    def __init__(
        self,
        etype,
        idim,
        elayers_sd,
        elayers_rec,
        eunits,
        eprojs,
        subsample,
        dropout,
        num_spkrs=2,
        in_channel=1,
    ):
        """Initialize the encoder of single-channel multi-speaker ASR."""
        super(EncoderMix, self).__init__()
        typ = etype.lstrip("vgg").rstrip("p")
        if typ not in ["lstm", "gru", "blstm", "bgru"]:
            logging.error("Error: need to specify an appropriate encoder architecture")
        if etype.startswith("vgg"):
            if etype[-1] == "p":
                # shared mixture encoder
                self.enc_mix = torch.nn.ModuleList([VGG2L(in_channel)])
                # one speaker-differentiating branch per speaker
                self.enc_sd = torch.nn.ModuleList(
                    [
                        torch.nn.ModuleList(
                            [
                                RNNP(
                                    get_vgg2l_odim(idim, in_channel=in_channel),
                                    elayers_sd,
                                    eunits,
                                    eprojs,
                                    subsample[: elayers_sd + 1],
                                    dropout,
                                    typ=typ,
                                )
                            ]
                        )
                        for i in range(num_spkrs)
                    ]
                )
                # shared recognition part
                self.enc_rec = torch.nn.ModuleList(
                    [
                        RNNP(
                            eprojs,
                            elayers_rec,
                            eunits,
                            eprojs,
                            subsample[elayers_sd:],
                            dropout,
                            typ=typ,
                        )
                    ]
                )
                logging.info("Use CNN-VGG + B" + typ.upper() + "P for encoder")
            else:
                logging.error(
                    f"Error: need to specify an appropriate encoder architecture. "
                    f"Illegal name {etype}"
                )
                sys.exit()
        else:
            logging.error(
                f"Error: need to specify an appropriate encoder architecture. "
                f"Illegal name {etype}"
            )
            sys.exit()

        self.num_spkrs = num_spkrs

    def forward(self, xs_pad, ilens):
        """Encodermix forward.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :return: list of batches of hidden state sequences
            [num_spkrs x (B, Tmax, eprojs)]
        :rtype: list
        """
        # shared mixture encoder
        for module in self.enc_mix:
            xs_pad, ilens, _ = module(xs_pad, ilens)

        # speaker-differentiating branches followed by the shared recognition part
        xs_pad_sd = [xs_pad for i in range(self.num_spkrs)]
        ilens_sd = [ilens for i in range(self.num_spkrs)]
        for ns in range(self.num_spkrs):
            # Encoder_SD: speaker-differentiating branch for speaker ns
            for module in self.enc_sd[ns]:
                xs_pad_sd[ns], ilens_sd[ns], _ = module(xs_pad_sd[ns], ilens_sd[ns])
            # Encoder_Rec: shared recognition part
            for module in self.enc_rec:
                xs_pad_sd[ns], ilens_sd[ns], _ = module(xs_pad_sd[ns], ilens_sd[ns])

        # make a mask to remove bias values in the padded part
        mask = to_device(xs_pad, make_pad_mask(ilens_sd[0]).unsqueeze(-1))

        return [x.masked_fill(mask, 0.0) for x in xs_pad_sd], ilens_sd, None
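
    # Shape flow through EncoderMix (a hedged sketch; dimensions are
    # illustrative, assuming 2 speakers and a vggblstmp-style setup where
    # VGG2L subsamples time by 4):
    #   (B, Tmax, D) --enc_mix (VGG2L)--> (B, Tmax / 4, D')  shared mixture part
    #   --enc_sd[ns] (RNNP)-->            (B, Tmax / 4, eprojs)  per speaker
    #   --enc_rec (RNNP)-->               (B, Tmax / 4, eprojs)  shared part
    #   output: [num_spkrs x (B, Tmax / 4, eprojs)], ilens_sd, None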


def encoder_for(args, idim, subsample):
    """Construct the encoder."""
    if getattr(args, "use_frontend", False):  # use getattr to keep compatibility
        # with a frontend, the mixture is already separated into single-speaker
        # streams, so a plain single-speaker encoder suffices
        return encoder_for_single(args, idim, subsample)
    else:
        return EncoderMix(
            args.etype,
            idim,
            args.elayers_sd,
            args.elayers,
            args.eunits,
            args.eprojs,
            subsample,
            args.dropout_rate,
            args.num_spkrs,
        )
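

# Hedged construction sketch for encoder_for (the argument values below are
# illustrative assumptions; a real setup builds `args` from the training
# parser defined above):
#   args = argparse.Namespace(
#       etype="vggblstmp", elayers_sd=2, elayers=3, eunits=300, eprojs=320,
#       dropout_rate=0.0, num_spkrs=2, use_frontend=False,
#   )
#   subsample = np.ones(1 + args.elayers_sd + args.elayers, dtype=np.int64)
#   enc = encoder_for(args, idim=83, subsample=subsample)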