#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import numpy as np
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

from funasr_detach.models.transducer.joint_network import JointNetwork


@dataclass
class Hypothesis:
    """Default hypothesis definition for Transducer search algorithms.

    Args:
        score: Total log-probability.
        yseq: Label sequence as integer ID sequence.
        dec_state: RNNDecoder or StatelessDecoder state.
                     ((N, 1, D_dec), (N, 1, D_dec) or None) or None
        lm_state: RNNLM state. ((N, D_lm), (N, D_lm)) or None

    """

    score: float
    yseq: List[int]
    dec_state: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None
    lm_state: Optional[Union[Dict[str, Any], List[Any]]] = None


@dataclass
class ExtendedHypothesis(Hypothesis):
    """Extended hypothesis definition for NSC beam search and mAES.

    Args:
        : Hypothesis dataclass arguments.
        dec_out: Decoder output sequence. (B, D_dec)
        lm_score: Log-probabilities of the LM for given label. (vocab_size)

    """

    dec_out: torch.Tensor = None
    lm_score: torch.Tensor = None
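

# A minimal sketch (not part of the FunASR API; the function name is
# illustrative only) of how the containers above are used: hypotheses carry a
# running log-probability that search code accumulates and ranks.
def _hypothesis_usage_sketch() -> List[Hypothesis]:
    """Build two toy hypotheses and rank them the way sort_nbest() does."""
    import math

    # yseq always starts with the blank ID 0; score is a total log-probability.
    hyp_a = Hypothesis(score=math.log(0.5), yseq=[0, 3])
    hyp_b = Hypothesis(score=math.log(0.2), yseq=[0, 3, 7])

    # Without length normalization hyp_a wins; dividing score by len(yseq)
    # (the score_norm option below) narrows the gap for longer sequences.
    return sorted([hyp_a, hyp_b], key=lambda h: h.score, reverse=True)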


class BeamSearchTransducer:
    """Beam search implementation for Transducer.

    Args:
        decoder: Decoder module.
        joint_network: Joint network module.
        beam_size: Size of the beam.
        lm: LM class.
        lm_weight: LM weight for soft fusion.
        search_type: Search algorithm to use during inference.
        max_sym_exp: Number of maximum symbol expansions at each time step. (TSD)
        u_max: Maximum expected target sequence length. (ALSD)
        nstep: Number of maximum expansion steps at each time step. (mAES)
        expansion_gamma: Allowed logp difference for prune-by-value method. (mAES)
        expansion_beta:
            Number of additional candidates for expanded hypotheses selection. (mAES)
        score_norm: Whether to normalize final scores by length.
        nbest: Number of final hypotheses.
        streaming: Whether to perform chunk-by-chunk beam search.

    """

    def __init__(
        self,
        decoder,
        joint_network: JointNetwork,
        beam_size: int,
        lm: Optional[torch.nn.Module] = None,
        lm_weight: float = 0.1,
        search_type: str = "default",
        max_sym_exp: int = 3,
        u_max: int = 50,
        nstep: int = 2,
        expansion_gamma: float = 2.3,
        expansion_beta: int = 2,
        score_norm: bool = False,
        nbest: int = 1,
        streaming: bool = False,
    ) -> None:
        """Construct a BeamSearchTransducer object."""
        super().__init__()

        self.decoder = decoder
        self.joint_network = joint_network

        self.vocab_size = decoder.vocab_size

        assert beam_size <= self.vocab_size, (
            "beam_size (%d) should be smaller than or equal to vocabulary size (%d)."
            % (beam_size, self.vocab_size)
        )
        self.beam_size = beam_size

        if search_type == "default":
            self.search_algorithm = self.default_beam_search
        elif search_type == "tsd":
            assert max_sym_exp > 1, (
                "max_sym_exp (%d) should be greater than one." % max_sym_exp
            )
            self.max_sym_exp = max_sym_exp

            self.search_algorithm = self.time_sync_decoding
        elif search_type == "alsd":
            assert not streaming, "ALSD is not available in streaming mode."
            assert u_max >= 0, (
                "u_max should be a non-negative integer, a portion of max_T."
            )
            self.u_max = u_max

            self.search_algorithm = self.align_length_sync_decoding
        elif search_type == "maes":
            assert self.vocab_size >= beam_size + expansion_beta, (
                "beam_size (%d) + expansion_beta (%d) "
                "should be smaller than or equal to vocab size (%d)."
                % (beam_size, expansion_beta, self.vocab_size)
            )
            self.max_candidates = beam_size + expansion_beta

            self.nstep = nstep
            self.expansion_gamma = expansion_gamma

            self.search_algorithm = self.modified_adaptive_expansion_search
        else:
            raise NotImplementedError(
                "Specified search type (%s) is not supported." % search_type
            )

        self.use_lm = lm is not None

        if self.use_lm:
            assert hasattr(lm, "rnn_type"), "Transformer LM is currently not supported."

            self.sos = self.vocab_size - 1

            self.lm = lm
            self.lm_weight = lm_weight

        self.score_norm = score_norm
        self.nbest = nbest

        self.reset_inference_cache()

    def __call__(
        self,
        enc_out: torch.Tensor,
        is_final: bool = True,
    ) -> List[Hypothesis]:
        """Perform beam search.

        Args:
            enc_out: Encoder output sequence. (T, D_enc)
            is_final: Whether enc_out is the final chunk of data.

        Returns:
            nbest_hyps: N-best decoding results.

        """
        self.decoder.set_device(enc_out.device)

        hyps = self.search_algorithm(enc_out)

        if is_final:
            self.reset_inference_cache()

            return self.sort_nbest(hyps)

        self.search_cache = hyps

        return hyps
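
    # Usage sketch with hypothetical names (`decoder`, `joint_net`, and the
    # encoder outputs are assumed to exist; they are not defined here):
    #
    #     searcher = BeamSearchTransducer(decoder, joint_net, beam_size=5)
    #     nbest = searcher(enc_out)                    # full utterance
    #     partial = searcher(chunk, is_final=False)    # streaming: cache kept
    #     final = searcher(last_chunk, is_final=True)  # flush cache, sort N-best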

    def reset_inference_cache(self) -> None:
        """Reset cache for decoder scoring and streaming."""
        self.decoder.score_cache = {}
        self.search_cache = None

    def sort_nbest(self, hyps: List[Hypothesis]) -> List[Hypothesis]:
        """Sort hypotheses in-place by score, or by length-normalized score.

        Args:
            hyps: Hypotheses.

        Returns:
            hyps: Sorted hypotheses, truncated to the N best.

        """
        if self.score_norm:
            hyps.sort(key=lambda x: x.score / len(x.yseq), reverse=True)
        else:
            hyps.sort(key=lambda x: x.score, reverse=True)

        return hyps[: self.nbest]

    def recombine_hyps(self, hyps: List[Hypothesis]) -> List[Hypothesis]:
        """Recombine hypotheses with the same label ID sequence.

        Args:
            hyps: Hypotheses.

        Returns:
            final: Recombined hypotheses.

        """
        final = {}

        for hyp in hyps:
            str_yseq = "_".join(map(str, hyp.yseq))

            if str_yseq in final:
                final[str_yseq].score = np.logaddexp(final[str_yseq].score, hyp.score)
            else:
                final[str_yseq] = hyp

        return [*final.values()]
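
    # Worked example for the recombination above: since
    # np.logaddexp(log a, log b) == log(a + b), two copies of the same label
    # sequence with probabilities 0.5 and 0.25 merge into one hypothesis of
    # probability 0.75; duplicate alignments pool their mass instead of
    # competing for beam slots:
    #
    #     >>> np.isclose(np.logaddexp(np.log(0.5), np.log(0.25)), np.log(0.75))
    #     True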

    def select_k_expansions(
        self,
        hyps: List[ExtendedHypothesis],
        topk_idx: torch.Tensor,
        topk_logp: torch.Tensor,
    ) -> List[ExtendedHypothesis]:
        """Return K hypothesis candidates for expansion from a list of hypotheses.

        K candidates are selected according to the extended hypotheses
        probabilities and a prune-by-value method, where K is equal to
        beam_size + expansion_beta.

        Args:
            hyps: Hypotheses.
            topk_idx: Indices of candidate hypotheses.
            topk_logp: Log-probabilities of candidate hypotheses.

        Returns:
            k_expansions: Best K expansion hypotheses candidates.

        """
        k_expansions = []

        for i, hyp in enumerate(hyps):
            hyp_i = [
                (int(k), hyp.score + float(v))
                for k, v in zip(topk_idx[i], topk_logp[i])
            ]
            k_best_exp = max(hyp_i, key=lambda x: x[1])[1]

            k_expansions.append(
                sorted(
                    filter(
                        lambda x: (k_best_exp - self.expansion_gamma) <= x[1], hyp_i
                    ),
                    key=lambda x: x[1],
                    reverse=True,
                )
            )

        return k_expansions
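
    # Worked example for the prune-by-value rule above, with the default
    # expansion_gamma=2.3: if one hypothesis' candidate expansions score
    # [-1.0, -2.5, -4.0], the best is -1.0, so only candidates scoring at
    # least -1.0 - 2.3 = -3.3 survive (here -1.0 and -2.5), sorted best-first.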

    def create_lm_batch_inputs(self, hyps_seq: List[List[int]]) -> torch.Tensor:
        """Make batch of inputs with left padding for LM scoring.

        Args:
            hyps_seq: Hypothesis sequences.

        Returns:
            : Padded batch of sequences.

        """
        max_len = max([len(h) for h in hyps_seq])

        # torch.LongTensor is a legacy CPU-only constructor and rejects a
        # device argument; torch.tensor builds the batch on the right device.
        return torch.tensor(
            [[self.sos] + ([0] * (max_len - len(h))) + h[1:] for h in hyps_seq],
            dtype=torch.long,
            device=self.decoder.device,
        )
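
    # Layout example for the batching above, assuming an illustrative sos=31:
    # with hyps_seq=[[0, 5], [0, 7, 9]], max_len is 3 and the result is
    #
    #     [[31, 0, 5],
    #      [31, 7, 9]]
    #
    # i.e. the leading blank of each sequence is replaced by sos and shorter
    # sequences are left-padded with the blank ID 0.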

    def default_beam_search(self, enc_out: torch.Tensor) -> List[Hypothesis]:
        """Beam search implementation without prefix search.

        Modified from https://arxiv.org/pdf/1211.3711.pdf

        Args:
            enc_out: Encoder output sequence. (T, D)

        Returns:
            nbest_hyps: N-best hypotheses.

        """
        beam_k = min(self.beam_size, (self.vocab_size - 1))
        max_t = len(enc_out)

        if self.search_cache is not None:
            kept_hyps = self.search_cache
        else:
            kept_hyps = [
                Hypothesis(
                    score=0.0,
                    yseq=[0],
                    dec_state=self.decoder.init_state(1),
                )
            ]

        for t in range(max_t):
            hyps = kept_hyps
            kept_hyps = []

            while True:
                max_hyp = max(hyps, key=lambda x: x.score)
                hyps.remove(max_hyp)

                label = torch.full(
                    (1, 1),
                    max_hyp.yseq[-1],
                    dtype=torch.long,
                    device=self.decoder.device,
                )
                dec_out, state = self.decoder.score(
                    label,
                    max_hyp.yseq,
                    max_hyp.dec_state,
                )

                logp = torch.log_softmax(
                    self.joint_network(enc_out[t : t + 1, :], dec_out),
                    dim=-1,
                ).squeeze(0)
                top_k = logp[1:].topk(beam_k, dim=-1)

                kept_hyps.append(
                    Hypothesis(
                        score=(max_hyp.score + float(logp[0:1])),
                        yseq=max_hyp.yseq,
                        dec_state=max_hyp.dec_state,
                        lm_state=max_hyp.lm_state,
                    )
                )

                if self.use_lm:
                    # torch.LongTensor does not accept a device argument, so
                    # the LM input is built with torch.tensor instead.
                    lm_scores, lm_state = self.lm.score(
                        torch.tensor(
                            [self.sos] + max_hyp.yseq[1:],
                            dtype=torch.long,
                            device=self.decoder.device,
                        ),
                        max_hyp.lm_state,
                        None,
                    )
                else:
                    lm_state = max_hyp.lm_state

                for logp, k in zip(*top_k):
                    score = max_hyp.score + float(logp)

                    if self.use_lm:
                        score += self.lm_weight * lm_scores[k + 1]

                    hyps.append(
                        Hypothesis(
                            score=score,
                            yseq=max_hyp.yseq + [int(k + 1)],
                            dec_state=state,
                            lm_state=lm_state,
                        )
                    )

                hyps_max = float(max(hyps, key=lambda x: x.score).score)
                kept_most_prob = sorted(
                    [hyp for hyp in kept_hyps if hyp.score > hyps_max],
                    key=lambda x: x.score,
                )

                if len(kept_most_prob) >= self.beam_size:
                    kept_hyps = kept_most_prob

                    break

        return kept_hyps

    def align_length_sync_decoding(
        self,
        enc_out: torch.Tensor,
    ) -> List[Hypothesis]:
        """Alignment-length synchronous beam search implementation.

        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            enc_out: Encoder output sequence. (T, D)

        Returns:
            nbest_hyps: N-best hypotheses.

        """
        t_max = int(enc_out.size(0))
        u_max = min(self.u_max, (t_max - 1))

        B = [Hypothesis(yseq=[0], score=0.0, dec_state=self.decoder.init_state(1))]
        final = []

        if self.use_lm:
            B[0].lm_state = self.lm.zero_state()

        # n indexes the joint (t + u) decoding step; named n so that the
        # per-hypothesis loop below does not shadow it.
        for n in range(t_max + u_max):
            A = []

            B_ = []
            B_enc_out = []

            for hyp in B:
                u = len(hyp.yseq) - 1
                t = n - u

                if t > (t_max - 1):
                    continue

                B_.append(hyp)
                B_enc_out.append((t, enc_out[t]))

            if B_:
                beam_enc_out = torch.stack([b[1] for b in B_enc_out])

                beam_dec_out, beam_state = self.decoder.batch_score(B_)

                beam_logp = torch.log_softmax(
                    self.joint_network(beam_enc_out, beam_dec_out),
                    dim=-1,
                )
                beam_topk = beam_logp[:, 1:].topk(self.beam_size, dim=-1)

                if self.use_lm:
                    beam_lm_scores, beam_lm_states = self.lm.batch_score(
                        self.create_lm_batch_inputs([b.yseq for b in B_]),
                        [b.lm_state for b in B_],
                        None,
                    )

                for i, hyp in enumerate(B_):
                    new_hyp = Hypothesis(
                        score=(hyp.score + float(beam_logp[i, 0])),
                        yseq=hyp.yseq[:],
                        dec_state=hyp.dec_state,
                        lm_state=hyp.lm_state,
                    )

                    A.append(new_hyp)

                    if B_enc_out[i][0] == (t_max - 1):
                        final.append(new_hyp)

                    for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                        new_hyp = Hypothesis(
                            score=(hyp.score + float(logp)),
                            yseq=(hyp.yseq[:] + [int(k)]),
                            dec_state=self.decoder.select_state(beam_state, i),
                            lm_state=hyp.lm_state,
                        )

                        if self.use_lm:
                            new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
                            new_hyp.lm_state = beam_lm_states[i]

                        A.append(new_hyp)

                B = sorted(A, key=lambda x: x.score, reverse=True)[: self.beam_size]
                B = self.recombine_hyps(B)

        if final:
            return final

        return B

    def time_sync_decoding(self, enc_out: torch.Tensor) -> List[Hypothesis]:
        """Time synchronous beam search implementation.

        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            enc_out: Encoder output sequence. (T, D)

        Returns:
            nbest_hyps: N-best hypotheses.

        """
        if self.search_cache is not None:
            B = self.search_cache
        else:
            B = [
                Hypothesis(
                    yseq=[0],
                    score=0.0,
                    dec_state=self.decoder.init_state(1),
                )
            ]

            if self.use_lm:
                B[0].lm_state = self.lm.zero_state()

        for enc_out_t in enc_out:
            A = []
            C = B

            enc_out_t = enc_out_t.unsqueeze(0)

            for v in range(self.max_sym_exp):
                D = []

                beam_dec_out, beam_state = self.decoder.batch_score(C)

                beam_logp = torch.log_softmax(
                    self.joint_network(enc_out_t, beam_dec_out),
                    dim=-1,
                )
                beam_topk = beam_logp[:, 1:].topk(self.beam_size, dim=-1)

                seq_A = [h.yseq for h in A]

                for i, hyp in enumerate(C):
                    if hyp.yseq not in seq_A:
                        A.append(
                            Hypothesis(
                                score=(hyp.score + float(beam_logp[i, 0])),
                                yseq=hyp.yseq[:],
                                dec_state=hyp.dec_state,
                                lm_state=hyp.lm_state,
                            )
                        )
                    else:
                        dict_pos = seq_A.index(hyp.yseq)

                        A[dict_pos].score = np.logaddexp(
                            A[dict_pos].score, (hyp.score + float(beam_logp[i, 0]))
                        )

                if v < (self.max_sym_exp - 1):
                    if self.use_lm:
                        beam_lm_scores, beam_lm_states = self.lm.batch_score(
                            self.create_lm_batch_inputs([c.yseq for c in C]),
                            [c.lm_state for c in C],
                            None,
                        )

                    for i, hyp in enumerate(C):
                        for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                            new_hyp = Hypothesis(
                                score=(hyp.score + float(logp)),
                                yseq=(hyp.yseq + [int(k)]),
                                dec_state=self.decoder.select_state(beam_state, i),
                                lm_state=hyp.lm_state,
                            )

                            if self.use_lm:
                                new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
                                new_hyp.lm_state = beam_lm_states[i]

                            D.append(new_hyp)

                    C = sorted(D, key=lambda x: x.score, reverse=True)[: self.beam_size]

            B = sorted(A, key=lambda x: x.score, reverse=True)[: self.beam_size]

        return B

    def modified_adaptive_expansion_search(
        self,
        enc_out: torch.Tensor,
    ) -> List[ExtendedHypothesis]:
        """Modified version of Adaptive Expansion Search (mAES).

        Based on AES (https://ieeexplore.ieee.org/document/9250505) and
        NSC (https://arxiv.org/abs/2201.05420).

        Args:
            enc_out: Encoder output sequence. (T, D_enc)

        Returns:
            nbest_hyps: N-best hypotheses.

        """
        if self.search_cache is not None:
            kept_hyps = self.search_cache
        else:
            init_tokens = [
                ExtendedHypothesis(
                    yseq=[0],
                    score=0.0,
                    dec_state=self.decoder.init_state(1),
                )
            ]

            beam_dec_out, beam_state = self.decoder.batch_score(init_tokens)

            if self.use_lm:
                beam_lm_scores, beam_lm_states = self.lm.batch_score(
                    self.create_lm_batch_inputs([h.yseq for h in init_tokens]),
                    [h.lm_state for h in init_tokens],
                    None,
                )
                lm_state = beam_lm_states[0]
                lm_score = beam_lm_scores[0]
            else:
                lm_state = None
                lm_score = None

            kept_hyps = [
                ExtendedHypothesis(
                    yseq=[0],
                    score=0.0,
                    dec_state=self.decoder.select_state(beam_state, 0),
                    dec_out=beam_dec_out[0],
                    lm_state=lm_state,
                    lm_score=lm_score,
                )
            ]

        for enc_out_t in enc_out:
            hyps = kept_hyps
            kept_hyps = []

            beam_enc_out = enc_out_t.unsqueeze(0)

            list_b = []
            for n in range(self.nstep):
                beam_dec_out = torch.stack([h.dec_out for h in hyps])

                beam_logp, beam_idx = torch.log_softmax(
                    self.joint_network(beam_enc_out, beam_dec_out),
                    dim=-1,
                ).topk(self.max_candidates, dim=-1)

                k_expansions = self.select_k_expansions(hyps, beam_idx, beam_logp)

                list_exp = []
                for i, hyp in enumerate(hyps):
                    for k, new_score in k_expansions[i]:
                        new_hyp = ExtendedHypothesis(
                            yseq=hyp.yseq[:],
                            score=new_score,
                            dec_out=hyp.dec_out,
                            dec_state=hyp.dec_state,
                            lm_state=hyp.lm_state,
                            lm_score=hyp.lm_score,
                        )

                        if k == 0:
                            list_b.append(new_hyp)
                        else:
                            new_hyp.yseq.append(int(k))

                            if self.use_lm:
                                new_hyp.score += self.lm_weight * float(hyp.lm_score[k])

                            list_exp.append(new_hyp)

                if not list_exp:
                    kept_hyps = sorted(
                        self.recombine_hyps(list_b), key=lambda x: x.score, reverse=True
                    )[: self.beam_size]

                    break
                else:
                    beam_dec_out, beam_state = self.decoder.batch_score(list_exp)

                    if self.use_lm:
                        beam_lm_scores, beam_lm_states = self.lm.batch_score(
                            self.create_lm_batch_inputs([h.yseq for h in list_exp]),
                            [h.lm_state for h in list_exp],
                            None,
                        )

                    if n < (self.nstep - 1):
                        for i, hyp in enumerate(list_exp):
                            hyp.dec_out = beam_dec_out[i]
                            hyp.dec_state = self.decoder.select_state(beam_state, i)

                            if self.use_lm:
                                hyp.lm_state = beam_lm_states[i]
                                hyp.lm_score = beam_lm_scores[i]

                        hyps = list_exp[:]
                    else:
                        beam_logp = torch.log_softmax(
                            self.joint_network(beam_enc_out, beam_dec_out),
                            dim=-1,
                        )

                        for i, hyp in enumerate(list_exp):
                            hyp.score += float(beam_logp[i, 0])

                            hyp.dec_out = beam_dec_out[i]
                            hyp.dec_state = self.decoder.select_state(beam_state, i)

                            if self.use_lm:
                                hyp.lm_state = beam_lm_states[i]
                                hyp.lm_score = beam_lm_scores[i]

                        kept_hyps = sorted(
                            self.recombine_hyps(list_b + list_exp),
                            key=lambda x: x.score,
                            reverse=True,
                        )[: self.beam_size]

        return kept_hyps