|
|
|
|
|
|
|
"""RNN sequence-to-sequence speech translation model (pytorch).""" |
|
|
|
import argparse |
|
import copy |
|
import logging |
|
import math |
|
import os |
|
|
|
import editdistance |
|
import nltk |
|
|
|
import chainer |
|
import numpy as np |
|
import six |
|
import torch |
|
|
|
from itertools import groupby |
|
|
|
from chainer import reporter |
|
|
|
from espnet.nets.e2e_asr_common import label_smoothing_dist |
|
from espnet.nets.pytorch_backend.ctc import CTC |
|
from espnet.nets.pytorch_backend.initialization import lecun_normal_init_parameters |
|
from espnet.nets.pytorch_backend.initialization import set_forget_bias_to_one |
|
from espnet.nets.pytorch_backend.nets_utils import get_subsample |
|
from espnet.nets.pytorch_backend.nets_utils import pad_list |
|
from espnet.nets.pytorch_backend.nets_utils import to_device |
|
from espnet.nets.pytorch_backend.nets_utils import to_torch_tensor |
|
from espnet.nets.pytorch_backend.rnn.argument import ( |
|
add_arguments_rnn_encoder_common, |
|
add_arguments_rnn_decoder_common, |
|
add_arguments_rnn_attention_common, |
|
) |
|
from espnet.nets.pytorch_backend.rnn.attentions import att_for |
|
from espnet.nets.pytorch_backend.rnn.decoders import decoder_for |
|
from espnet.nets.pytorch_backend.rnn.encoders import encoder_for |
|
from espnet.nets.st_interface import STInterface |
|
from espnet.utils.fill_missing_args import fill_missing_args |
|
|
|
CTC_LOSS_THRESHOLD = 10000  # losses at or above this value (or NaN) are treated as divergent
|
|
|
|
|
class Reporter(chainer.Chain): |
|
"""A chainer reporter wrapper.""" |
|
|
|
def report( |
|
self, |
|
loss_asr, |
|
loss_mt, |
|
loss_st, |
|
acc_asr, |
|
acc_mt, |
|
acc, |
|
cer_ctc, |
|
cer, |
|
wer, |
|
bleu, |
|
mtl_loss, |
|
): |
|
"""Report at every step.""" |
|
reporter.report({"loss_asr": loss_asr}, self) |
|
reporter.report({"loss_mt": loss_mt}, self) |
|
reporter.report({"loss_st": loss_st}, self) |
|
reporter.report({"acc_asr": acc_asr}, self) |
|
reporter.report({"acc_mt": acc_mt}, self) |
|
reporter.report({"acc": acc}, self) |
|
reporter.report({"cer_ctc": cer_ctc}, self) |
|
reporter.report({"cer": cer}, self) |
|
reporter.report({"wer": wer}, self) |
|
reporter.report({"bleu": bleu}, self) |
|
logging.info("mtl loss:" + str(mtl_loss)) |
|
reporter.report({"loss": mtl_loss}, self) |
|
|
|
|
|
class E2E(STInterface, torch.nn.Module): |
|
"""E2E module. |
|
|
|
:param int idim: dimension of inputs |
|
:param int odim: dimension of outputs |
|
:param Namespace args: argument Namespace containing options |
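
    A typical way to build the model is from a trained snapshot and its
    config (a sketch; the file names are placeholders)::

        >>> from espnet.asr.asr_utils import get_model_conf, torch_load
        >>> idim, odim, train_args = get_model_conf("model.acc.best")
        >>> model = E2E(idim, odim, train_args)
        >>> torch_load("model.acc.best", model)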
|
|
|
""" |
|
|
|
@staticmethod |
|
def add_arguments(parser): |
|
"""Add arguments.""" |
|
E2E.encoder_add_arguments(parser) |
|
E2E.attention_add_arguments(parser) |
|
E2E.decoder_add_arguments(parser) |
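        # a sketch of typical use, starting from a fresh parser:
        #   parser = E2E.add_arguments(argparse.ArgumentParser())
        #   args = parser.parse_args([])  # RNN encoder/attention/decoder defaults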
|
return parser |
|
|
|
@staticmethod |
|
def encoder_add_arguments(parser): |
|
"""Add arguments for the encoder.""" |
|
group = parser.add_argument_group("E2E encoder setting") |
|
group = add_arguments_rnn_encoder_common(group) |
|
return parser |
|
|
|
@staticmethod |
|
def attention_add_arguments(parser): |
|
"""Add arguments for the attention.""" |
|
group = parser.add_argument_group("E2E attention setting") |
|
group = add_arguments_rnn_attention_common(group) |
|
return parser |
|
|
|
@staticmethod |
|
def decoder_add_arguments(parser): |
|
"""Add arguments for the decoder.""" |
|
group = parser.add_argument_group("E2E decoder setting") |
|
group = add_arguments_rnn_decoder_common(group) |
|
return parser |
|
|
|
def get_total_subsampling_factor(self): |
|
"""Get total subsampling factor.""" |
|
return self.enc.conv_subsampling_factor * int(np.prod(self.subsample)) |
|
|
|
def __init__(self, idim, odim, args): |
|
"""Construct an E2E object. |
|
|
|
:param int idim: dimension of inputs |
|
:param int odim: dimension of outputs |
|
:param Namespace args: argument Namespace containing options |
|
""" |
|
super(E2E, self).__init__() |
|
torch.nn.Module.__init__(self) |
|
|
|
|
|
args = fill_missing_args(args, self.add_arguments) |
|
|
|
self.asr_weight = args.asr_weight |
|
self.mt_weight = args.mt_weight |
|
self.mtlalpha = args.mtlalpha |
|
        assert 0.0 <= self.asr_weight < 1.0, "asr_weight should be in [0.0, 1.0)"
        assert 0.0 <= self.mt_weight < 1.0, "mt_weight should be in [0.0, 1.0)"
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be in [0.0, 1.0]"
|
self.etype = args.etype |
|
self.verbose = args.verbose |
|
|
|
args.char_list = getattr(args, "char_list", None) |
|
self.char_list = args.char_list |
|
self.outdir = args.outdir |
|
self.space = args.sym_space |
|
self.blank = args.sym_blank |
|
self.reporter = Reporter() |
|
|
|
|
|
|
|
        # NOTE: the last vocabulary index is used for both sos and eos
        self.sos = odim - 1
        self.eos = odim - 1
        # NOTE: index 0 is reserved for <pad>; it overlaps the CTC blank class on
        # the ASR side, but blank labels are not used in MT
        self.pad = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
self.subsample = get_subsample(args, mode="st", arch="rnn") |
|
|
|
|
|
if args.lsm_type and os.path.isfile(args.train_json): |
|
logging.info("Use label smoothing with " + args.lsm_type) |
|
labeldist = label_smoothing_dist( |
|
odim, args.lsm_type, transcript=args.train_json |
|
) |
|
else: |
|
labeldist = None |
|
|
|
|
|
self.multilingual = getattr(args, "multilingual", False) |
|
self.replace_sos = getattr(args, "replace_sos", False) |
|
|
|
|
|
self.enc = encoder_for(args, idim, self.subsample) |
|
|
|
self.att = att_for(args) |
|
|
|
self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist) |
|
|
|
|
|
self.ctc = None |
|
self.att_asr = None |
|
self.dec_asr = None |
|
if self.asr_weight > 0: |
|
if self.mtlalpha > 0.0: |
|
self.ctc = CTC( |
|
odim, |
|
args.eprojs, |
|
args.dropout_rate, |
|
ctc_type=args.ctc_type, |
|
reduce=True, |
|
) |
|
if self.mtlalpha < 1.0: |
|
|
|
self.att_asr = att_for(args) |
|
|
|
args_asr = copy.deepcopy(args) |
|
args_asr.atype = "location" |
|
self.dec_asr = decoder_for( |
|
args_asr, odim, self.sos, self.eos, self.att_asr, labeldist |
|
) |
|
|
|
|
|
if self.mt_weight > 0: |
|
self.embed_mt = torch.nn.Embedding(odim, args.eunits, padding_idx=self.pad) |
|
self.dropout_mt = torch.nn.Dropout(p=args.dropout_rate) |
|
self.enc_mt = encoder_for( |
|
                args, args.eunits, subsample=np.ones(args.elayers + 1, dtype=np.int64)
|
) |
|
|
|
|
|
self.init_like_chainer() |
|
|
|
|
|
        if self.asr_weight > 0 and (args.report_cer or args.report_wer):
|
recog_args = { |
|
"beam_size": args.beam_size, |
|
"penalty": args.penalty, |
|
"ctc_weight": args.ctc_weight, |
|
"maxlenratio": args.maxlenratio, |
|
"minlenratio": args.minlenratio, |
|
"lm_weight": args.lm_weight, |
|
"rnnlm": args.rnnlm, |
|
"nbest": args.nbest, |
|
"space": args.sym_space, |
|
"blank": args.sym_blank, |
|
"tgt_lang": False, |
|
} |
|
|
|
self.recog_args = argparse.Namespace(**recog_args) |
|
self.report_cer = args.report_cer |
|
self.report_wer = args.report_wer |
|
else: |
|
self.report_cer = False |
|
self.report_wer = False |
|
if args.report_bleu: |
|
trans_args = { |
|
"beam_size": args.beam_size, |
|
"penalty": args.penalty, |
|
"ctc_weight": 0, |
|
"maxlenratio": args.maxlenratio, |
|
"minlenratio": args.minlenratio, |
|
"lm_weight": args.lm_weight, |
|
"rnnlm": args.rnnlm, |
|
"nbest": args.nbest, |
|
"space": args.sym_space, |
|
"blank": args.sym_blank, |
|
"tgt_lang": False, |
|
} |
|
|
|
self.trans_args = argparse.Namespace(**trans_args) |
|
self.report_bleu = args.report_bleu |
|
else: |
|
self.report_bleu = False |
|
self.rnnlm = None |
|
|
|
self.logzero = -10000000000.0 |
|
self.loss = None |
|
self.acc = None |
|
|
|
def init_like_chainer(self): |
|
"""Initialize weight like chainer. |
|
|
|
chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0 |
|
pytorch basically uses W, b ~ Uniform(-fan_in**-0.5, fan_in**-0.5) |
|
however, there are two exceptions as far as I know. |
|
- EmbedID.W ~ Normal(0, 1) |
|
- LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM) |
|
""" |
|
        lecun_normal_init_parameters(self)
        # exception 1: embedding weights ~ Normal(0, 1)
        self.dec.embed.weight.data.normal_(0, 1)
        # exception 2: set the decoder LSTM forget-gate biases to 1.0
        for i in six.moves.range(len(self.dec.decoder)):
            set_forget_bias_to_one(self.dec.decoder[i].bias_ih)
|
|
|
def forward(self, xs_pad, ilens, ys_pad, ys_pad_src): |
|
"""E2E forward. |
|
|
|
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim) |
|
:param torch.Tensor ilens: batch of lengths of input sequences (B) |
|
        :param torch.Tensor ys_pad: batch of padded target token id sequence tensor (B, Lmax)
        :param torch.Tensor ys_pad_src: batch of padded source token id sequence tensor (B, Lmax_src)
|
:return: loss value |
|
:rtype: torch.Tensor |
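
        A sketch of a training step (placeholder batch tensors, shaped as above)::

            >>> loss = model(xs_pad, ilens, ys_pad, ys_pad_src)
            >>> loss.backward()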
|
""" |
|
|
|
        # extract the prepended target language ID (multilingual case)
        if self.multilingual:
|
tgt_lang_ids = ys_pad[:, 0:1] |
|
ys_pad = ys_pad[:, 1:] |
|
else: |
|
tgt_lang_ids = None |
|
|
|
|
|
        # encoder
        hs_pad, hlens, _ = self.enc(xs_pad, ilens)
|
|
|
|
|
        # ST attention loss
        self.loss_st, self.acc, _ = self.dec(
            hs_pad, hlens, ys_pad, lang_ids=tgt_lang_ids
        )
|
|
|
|
|
        # auxiliary ASR loss
        (
|
self.loss_asr_att, |
|
acc_asr, |
|
self.loss_asr_ctc, |
|
cer_ctc, |
|
cer, |
|
wer, |
|
) = self.forward_asr(hs_pad, hlens, ys_pad_src) |
|
|
|
|
|
        # auxiliary MT loss (source transcript in, target translation out)
        self.loss_mt, acc_mt = self.forward_mt(ys_pad_src, ys_pad)
|
|
|
|
|
        # BLEU computation for reporting (evaluation only)
        if self.training or not self.report_bleu:
|
self.bleu = 0.0 |
|
else: |
|
lpz = None |
|
|
|
nbest_hyps = self.dec.recognize_beam_batch( |
|
hs_pad, |
|
torch.tensor(hlens), |
|
lpz, |
|
self.trans_args, |
|
self.char_list, |
|
self.rnnlm, |
|
lang_ids=tgt_lang_ids.squeeze(1).tolist() |
|
if self.multilingual |
|
else None, |
|
) |
|
|
|
list_of_refs = [] |
|
hyps = [] |
|
y_hats = [nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps] |
|
for i, y_hat in enumerate(y_hats): |
|
y_true = ys_pad[i] |
|
|
|
seq_hat = [self.char_list[int(idx)] for idx in y_hat if int(idx) != -1] |
|
seq_true = [ |
|
self.char_list[int(idx)] for idx in y_true if int(idx) != -1 |
|
] |
|
seq_hat_text = "".join(seq_hat).replace(self.trans_args.space, " ") |
|
seq_hat_text = seq_hat_text.replace(self.trans_args.blank, "") |
|
seq_true_text = "".join(seq_true).replace(self.trans_args.space, " ") |
|
|
|
hyps += [seq_hat_text.split(" ")] |
|
list_of_refs += [[seq_true_text.split(" ")]] |
|
|
|
self.bleu = nltk.bleu_score.corpus_bleu(list_of_refs, hyps) * 100 |
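
        # total objective (a summary of the computation below):
        #   loss = (1 - asr_weight - mt_weight) * loss_st
        #          + asr_weight * (mtlalpha * loss_asr_ctc + (1 - mtlalpha) * loss_asr_att)
        #          + mt_weight * loss_mt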
|
|
|
asr_ctc_weight = self.mtlalpha |
|
self.loss_asr = ( |
|
asr_ctc_weight * self.loss_asr_ctc |
|
+ (1 - asr_ctc_weight) * self.loss_asr_att |
|
) |
|
self.loss = ( |
|
(1 - self.asr_weight - self.mt_weight) * self.loss_st |
|
+ self.asr_weight * self.loss_asr |
|
+ self.mt_weight * self.loss_mt |
|
) |
|
loss_st_data = float(self.loss_st) |
|
loss_asr_data = float(self.loss_asr) |
|
loss_mt_data = float(self.loss_mt) |
|
loss_data = float(self.loss) |
|
if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data): |
|
self.reporter.report( |
|
loss_asr_data, |
|
loss_mt_data, |
|
loss_st_data, |
|
acc_asr, |
|
acc_mt, |
|
self.acc, |
|
cer_ctc, |
|
cer, |
|
wer, |
|
self.bleu, |
|
loss_data, |
|
) |
|
else: |
|
logging.warning("loss (=%f) is not correct", loss_data) |
|
return self.loss |
|
|
|
def forward_asr(self, hs_pad, hlens, ys_pad): |
|
"""Forward pass in the auxiliary ASR task. |
|
|
|
        :param torch.Tensor hs_pad: batch of padded encoder hidden state sequences (B, Tmax, D)
|
:param torch.Tensor hlens: batch of lengths of input sequences (B) |
|
:param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) |
|
:return: ASR attention loss value |
|
:rtype: torch.Tensor |
|
:return: accuracy in ASR attention decoder |
|
:rtype: float |
|
:return: ASR CTC loss value |
|
:rtype: torch.Tensor |
|
:return: character error rate from CTC prediction |
|
:rtype: float |
|
        :return: character error rate from attention decoder prediction
|
:rtype: float |
|
        :return: word error rate from attention decoder prediction
|
:rtype: float |
|
""" |
|
loss_att, loss_ctc = 0.0, 0.0 |
|
acc = None |
|
cer, wer = None, None |
|
cer_ctc = None |
|
if self.asr_weight == 0: |
|
return loss_att, acc, loss_ctc, cer_ctc, cer, wer |
|
|
|
|
|
        # attention loss
        if self.mtlalpha < 1:
            loss_att, acc, _ = self.dec_asr(hs_pad, hlens, ys_pad)
|
|
|
|
|
if not self.training and (self.report_cer or self.report_wer): |
|
if self.mtlalpha > 0 and self.recog_args.ctc_weight > 0.0: |
|
lpz = self.ctc.log_softmax(hs_pad).data |
|
else: |
|
lpz = None |
|
|
|
word_eds, word_ref_lens, char_eds, char_ref_lens = [], [], [], [] |
|
nbest_hyps_asr = self.dec_asr.recognize_beam_batch( |
|
hs_pad, |
|
torch.tensor(hlens), |
|
lpz, |
|
self.recog_args, |
|
self.char_list, |
|
self.rnnlm, |
|
) |
|
|
|
y_hats = [nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps_asr] |
|
for i, y_hat in enumerate(y_hats): |
|
y_true = ys_pad[i] |
|
|
|
seq_hat = [ |
|
self.char_list[int(idx)] for idx in y_hat if int(idx) != -1 |
|
] |
|
seq_true = [ |
|
self.char_list[int(idx)] for idx in y_true if int(idx) != -1 |
|
] |
|
seq_hat_text = "".join(seq_hat).replace(self.recog_args.space, " ") |
|
seq_hat_text = seq_hat_text.replace(self.recog_args.blank, "") |
|
seq_true_text = "".join(seq_true).replace( |
|
self.recog_args.space, " " |
|
) |
|
|
|
hyp_words = seq_hat_text.split() |
|
ref_words = seq_true_text.split() |
|
word_eds.append(editdistance.eval(hyp_words, ref_words)) |
|
word_ref_lens.append(len(ref_words)) |
|
hyp_chars = seq_hat_text.replace(" ", "") |
|
ref_chars = seq_true_text.replace(" ", "") |
|
char_eds.append(editdistance.eval(hyp_chars, ref_chars)) |
|
char_ref_lens.append(len(ref_chars)) |
|
|
|
wer = ( |
|
0.0 |
|
if not self.report_wer |
|
else float(sum(word_eds)) / sum(word_ref_lens) |
|
) |
|
cer = ( |
|
0.0 |
|
if not self.report_cer |
|
else float(sum(char_eds)) / sum(char_ref_lens) |
|
) |
|
|
|
|
|
        # CTC loss
        if self.mtlalpha > 0:
|
loss_ctc = self.ctc(hs_pad, hlens, ys_pad) |
|
|
|
|
|
if self.char_list is not None: |
|
cers = [] |
|
y_hats = self.ctc.argmax(hs_pad).data |
|
for i, y in enumerate(y_hats): |
|
y_hat = [x[0] for x in groupby(y)] |
|
y_true = ys_pad[i] |
|
|
|
seq_hat = [ |
|
self.char_list[int(idx)] for idx in y_hat if int(idx) != -1 |
|
] |
|
seq_true = [ |
|
self.char_list[int(idx)] for idx in y_true if int(idx) != -1 |
|
] |
|
seq_hat_text = "".join(seq_hat).replace(self.space, " ") |
|
seq_hat_text = seq_hat_text.replace(self.blank, "") |
|
seq_true_text = "".join(seq_true).replace(self.space, " ") |
|
|
|
hyp_chars = seq_hat_text.replace(" ", "") |
|
ref_chars = seq_true_text.replace(" ", "") |
|
if len(ref_chars) > 0: |
|
cers.append( |
|
editdistance.eval(hyp_chars, ref_chars) / len(ref_chars) |
|
) |
|
cer_ctc = sum(cers) / len(cers) if cers else None |
|
|
|
return loss_att, acc, loss_ctc, cer_ctc, cer, wer |
|
|
|
def forward_mt(self, xs_pad, ys_pad): |
|
"""Forward pass in the auxiliary MT task. |
|
|
|
        :param torch.Tensor xs_pad: batch of padded source token id sequences (B, Lmax_src)
|
:param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) |
|
:return: MT loss value |
|
:rtype: torch.Tensor |
|
:return: accuracy in MT decoder |
|
:rtype: float |
|
""" |
|
loss = 0.0 |
|
acc = 0.0 |
|
if self.mt_weight == 0: |
|
return loss, acc |
|
|
|
        # NOTE: xs_pad is padded with -1
        ilens = torch.sum(xs_pad != -1, dim=1).cpu().numpy()
        # strip the -1 padding and re-pad with zeros for the embedding lookup
        xs = [x[x != -1] for x in xs_pad]
        xs_zero_pad = pad_list(xs, self.pad)
|
hs_pad, hlens, _ = self.enc_mt( |
|
self.dropout_mt(self.embed_mt(xs_zero_pad)), ilens |
|
) |
|
loss, acc, _ = self.dec(hs_pad, hlens, ys_pad) |
|
return loss, acc |
|
|
|
def scorers(self): |
|
"""Scorers.""" |
|
return dict(decoder=self.dec) |
|
|
|
def encode(self, x): |
|
"""Encode acoustic features. |
|
|
|
:param ndarray x: input acoustic feature (T, D) |
|
:return: encoder outputs |
|
:rtype: torch.Tensor |
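
        A sketch (``x`` is a placeholder (T, D) feature array)::

            >>> hs = model.encode(x)  # (T', D') encoder states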
|
""" |
|
self.eval() |
|
ilens = [x.shape[0]] |
|
|
|
|
|
x = x[:: self.subsample[0], :] |
|
p = next(self.parameters()) |
|
h = torch.as_tensor(x, device=p.device, dtype=p.dtype) |
|
|
|
hs = h.contiguous().unsqueeze(0) |
|
|
|
|
|
hs, _, _ = self.enc(hs, ilens) |
|
return hs.squeeze(0) |
|
|
|
def translate(self, x, trans_args, char_list, rnnlm=None): |
|
"""E2E beam search. |
|
|
|
:param ndarray x: input acoustic feature (T, D) |
|
:param Namespace trans_args: argument Namespace containing options |
|
:param list char_list: list of characters |
|
:param torch.nn.Module rnnlm: language model module |
|
:return: N-best decoding results |
|
:rtype: list |
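
        Example (a sketch; ``x``, ``trans_args`` and ``char_list`` are
        placeholders that normally come from a trained recipe)::

            >>> nbest = model.translate(x, trans_args, char_list)
            >>> best_yseq = nbest[0]["yseq"]  # token ids incl. sos/eos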
|
""" |
|
logging.info("input lengths: " + str(x.shape[0])) |
|
hs = self.encode(x).unsqueeze(0) |
|
logging.info("encoder output lengths: " + str(hs.size(1))) |
|
|
|
|
|
|
|
y = self.dec.recognize_beam(hs[0], None, trans_args, char_list, rnnlm) |
|
return y |
|
|
|
def translate_batch(self, xs, trans_args, char_list, rnnlm=None): |
|
"""E2E batch beam search. |
|
|
|
:param list xs: list of input acoustic feature arrays [(T_1, D), (T_2, D), ...] |
|
:param Namespace trans_args: argument Namespace containing options |
|
:param list char_list: list of characters |
|
:param torch.nn.Module rnnlm: language model module |
|
:return: N-best decoding results |
|
:rtype: list |
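
        Usage mirrors :meth:`translate` but takes a list of feature arrays
        (a sketch; inputs are placeholders)::

            >>> nbest_list = model.translate_batch([x1, x2], trans_args, char_list)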
|
""" |
|
prev = self.training |
|
self.eval() |
|
ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64) |
|
|
|
|
|
xs = [xx[:: self.subsample[0], :] for xx in xs] |
|
xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs] |
|
xs_pad = pad_list(xs, 0.0) |
|
|
|
|
|
hs_pad, hlens, _ = self.enc(xs_pad, ilens) |
|
|
|
|
|
hlens = torch.tensor(list(map(int, hlens))) |
|
y = self.dec.recognize_beam_batch( |
|
hs_pad, hlens, None, trans_args, char_list, rnnlm |
|
) |
|
|
|
if prev: |
|
self.train() |
|
return y |
|
|
|
def calculate_all_attentions(self, xs_pad, ilens, ys_pad, ys_pad_src): |
|
"""E2E attention calculation. |
|
|
|
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim) |
|
:param torch.Tensor ilens: batch of lengths of input sequences (B) |
|
:param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax) |
|
:param torch.Tensor ys_pad_src: |
|
batch of padded token id sequence tensor (B, Lmax) |
|
:return: attention weights with the following shape, |
|
1) multi-head case => attention weights (B, H, Lmax, Tmax), |
|
2) other case => attention weights (B, Lmax, Tmax). |
|
:rtype: float ndarray |
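
        A sketch (placeholder batch tensors, shaped as above)::

            >>> att_ws = model.calculate_all_attentions(xs_pad, ilens, ys_pad, ys_pad_src)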
|
""" |
|
self.eval() |
|
with torch.no_grad(): |
|
|
|
if self.multilingual: |
|
tgt_lang_ids = ys_pad[:, 0:1] |
|
ys_pad = ys_pad[:, 1:] |
|
else: |
|
tgt_lang_ids = None |
|
hpad, hlens, _ = self.enc(xs_pad, ilens) |
|
|
|
|
|
att_ws = self.dec.calculate_all_attentions( |
|
hpad, hlens, ys_pad, lang_ids=tgt_lang_ids |
|
) |
|
self.train() |
|
return att_ws |
|
|
|
def calculate_all_ctc_probs(self, xs_pad, ilens, ys_pad, ys_pad_src): |
|
"""E2E CTC probability calculation. |
|
|
|
        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
|
:param torch.Tensor ilens: batch of lengths of input sequences (B) |
|
:param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax) |
|
:param torch.Tensor |
|
ys_pad_src: batch of padded token id sequence tensor (B, Lmax) |
|
:return: CTC probability (B, Tmax, vocab) |
|
:rtype: float ndarray |
|
""" |
|
probs = None |
|
if self.asr_weight == 0 or self.mtlalpha == 0: |
|
return probs |
|
|
|
self.eval() |
|
with torch.no_grad(): |
|
|
|
hpad, hlens, _ = self.enc(xs_pad, ilens) |
|
|
|
|
|
probs = self.ctc.softmax(hpad).cpu().numpy() |
|
self.train() |
|
return probs |
|
|
|
    def subsample_frames(self, x):
        """Subsample speech frames in the encoder."""
        # subsample along the time axis
        x = x[:: self.subsample[0], :]
        ilen = [x.shape[0]]
        h = to_device(self, torch.from_numpy(np.array(x, dtype=np.float32)))
        h = h.contiguous()
        return h, ilen
|
|