# Copyright (c) 2023 Binbin Zhang (binbzha@qq.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from collections import defaultdict
from typing import List, Dict

import torch
from torch.nn.utils.rnn import pad_sequence

from wenet.utils.common import (add_sos_eos, log_add, add_whisper_tokens,
                                mask_to_bias)
from wenet.utils.ctc_utils import remove_duplicates_and_blank
from wenet.utils.mask import (make_pad_mask, mask_finished_preds,
                              mask_finished_scores, subsequent_mask)
from wenet.utils.context_graph import ContextGraph, ContextState


class DecodeResult:

    def __init__(self,
                 tokens: List[int],
                 score: float = 0.0,
                 confidence: float = 0.0,
                 tokens_confidence: List[float] = None,
                 times: List[int] = None,
                 nbest: List[List[int]] = None,
                 nbest_scores: List[float] = None,
                 nbest_times: List[List[int]] = None):
        """
        Args:
            tokens: decoded token list
            score: the total decoding score of this result
            confidence: the total confidence of this result, in [0, 1]
            tokens_confidence: confidence of each token
            times: timestamp (frame index) of each token
            nbest: nbest results
            nbest_scores: score of each nbest result
            nbest_times: timestamps of each nbest result
        """
        self.tokens = tokens
        self.score = score
        self.confidence = confidence
        self.tokens_confidence = tokens_confidence
        self.times = times
        self.nbest = nbest
        self.nbest_scores = nbest_scores
        self.nbest_times = nbest_times


class PrefixScore:
    """ For CTC prefix beam search """

    def __init__(self,
                 s: float = float('-inf'),
                 ns: float = float('-inf'),
                 v_s: float = float('-inf'),
                 v_ns: float = float('-inf'),
                 context_state: ContextState = None,
                 context_score: float = 0.0):
        self.s = s  # blank ending score
        self.ns = ns  # non-blank ending score
        self.v_s = v_s  # viterbi blank ending score
        self.v_ns = v_ns  # viterbi non-blank ending score
        self.cur_token_prob = float('-inf')  # prob of current token
        self.times_s = []  # times of viterbi blank path
        self.times_ns = []  # times of viterbi non-blank path
        self.context_state = context_state
        self.context_score = context_score
        self.has_context = False

    def score(self):
        return log_add(self.s, self.ns)

    def viterbi_score(self):
        return self.v_s if self.v_s > self.v_ns else self.v_ns

    def times(self):
        return self.times_s if self.v_s > self.v_ns else self.times_ns

    def total_score(self):
        return self.score() + self.context_score

    def copy_context(self, prefix_score):
        self.context_score = prefix_score.context_score
        self.context_state = prefix_score.context_state

    def update_context(self, context_graph, prefix_score, word_id):
        self.copy_context(prefix_score)
        (score, context_state) = context_graph.forward_one_step(
            prefix_score.context_state, word_id)
        self.context_score += score
        self.context_state = context_state


def ctc_greedy_search(ctc_probs: torch.Tensor,
                      ctc_lens: torch.Tensor,
                      blank_id: int = 0) -> List[DecodeResult]:
    batch_size = ctc_probs.shape[0]
    maxlen = ctc_probs.size(1)
    topk_prob, topk_index = ctc_probs.topk(1, dim=2)  # (B, maxlen, 1)
    topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
    mask = make_pad_mask(ctc_lens, maxlen)  # (B, maxlen)
    topk_index = topk_index.masked_fill_(mask, blank_id)  # (B, maxlen)
    hyps = [hyp.tolist() for hyp in topk_index]
    scores = topk_prob.max(1)
    results = []
    for hyp in hyps:
        r = DecodeResult(remove_duplicates_and_blank(hyp, blank_id))
        results.append(r)
    return results
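
# A minimal usage sketch of ctc_greedy_search, for illustration only. The
# tensor shapes and blank_id below are assumptions, not values taken from a
# real model:
#
#     ctc_probs = torch.randn(2, 8, 10).log_softmax(dim=2)  # (B, T, vocab)
#     ctc_lens = torch.tensor([8, 6], dtype=torch.long)     # valid frames
#     results = ctc_greedy_search(ctc_probs, ctc_lens, blank_id=0)
#     print(results[0].tokens)  # blank/duplicate-collapsed token ids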


def ctc_prefix_beam_search(
    ctc_probs: torch.Tensor,
    ctc_lens: torch.Tensor,
    beam_size: int,
    context_graph: ContextGraph = None,
    blank_id: int = 0,
) -> List[DecodeResult]:
    """
    Returns:
        List[DecodeResult]: nbest results for each utterance, best
            hypothesis first
    """
    batch_size = ctc_probs.shape[0]
    results = []
    # CTC prefix beam search cannot be parallelized, so decode one
    # utterance at a time
    for i in range(batch_size):
        ctc_prob = ctc_probs[i]
        num_t = ctc_lens[i]
        # 1. Initialize the beam with the empty prefix
        cur_hyps = [(tuple(),
                     PrefixScore(s=0.0,
                                 ns=-float('inf'),
                                 v_s=0.0,
                                 v_ns=0.0,
                                 context_state=None if context_graph is None
                                 else context_graph.root,
                                 context_score=0.0))]
        # 2. CTC beam search step by step
        for t in range(0, num_t):
            logp = ctc_prob[t]  # (vocab_size,)
            # key: prefix, value: PrefixScore
            next_hyps = defaultdict(lambda: PrefixScore())
            # 2.1 First beam prune: select topk best
            top_k_logp, top_k_index = logp.topk(beam_size)  # (beam_size,)
            for u in top_k_index:
                u = u.item()
                prob = logp[u].item()
                for prefix, prefix_score in cur_hyps:
                    last = prefix[-1] if len(prefix) > 0 else None
                    if u == blank_id:  # blank
                        next_score = next_hyps[prefix]
                        next_score.s = log_add(next_score.s,
                                               prefix_score.score() + prob)
                        next_score.v_s = prefix_score.viterbi_score() + prob
                        next_score.times_s = prefix_score.times().copy()
                        # prefix not changed, copy the context from prefix
                        if context_graph and not next_score.has_context:
                            next_score.copy_context(prefix_score)
                            next_score.has_context = True
                    elif u == last:
                        # Update *uu -> *u;
                        next_score1 = next_hyps[prefix]
                        next_score1.ns = log_add(next_score1.ns,
                                                 prefix_score.ns + prob)
                        if next_score1.v_ns < prefix_score.v_ns + prob:
                            next_score1.v_ns = prefix_score.v_ns + prob
                            if next_score1.cur_token_prob < prob:
                                next_score1.cur_token_prob = prob
                                next_score1.times_ns = \
                                    prefix_score.times_ns.copy()
                                next_score1.times_ns[-1] = t
                        if context_graph and not next_score1.has_context:
                            next_score1.copy_context(prefix_score)
                            next_score1.has_context = True
                        # Update *u-u -> *uu, where - is blank
                        n_prefix = prefix + (u, )
                        next_score2 = next_hyps[n_prefix]
                        next_score2.ns = log_add(next_score2.ns,
                                                 prefix_score.s + prob)
                        if next_score2.v_ns < prefix_score.v_s + prob:
                            next_score2.v_ns = prefix_score.v_s + prob
                            next_score2.cur_token_prob = prob
                            next_score2.times_ns = prefix_score.times_s.copy()
                            next_score2.times_ns.append(t)
                        if context_graph and not next_score2.has_context:
                            next_score2.update_context(context_graph,
                                                       prefix_score, u)
                            next_score2.has_context = True
                    else:
                        n_prefix = prefix + (u, )
                        next_score = next_hyps[n_prefix]
                        next_score.ns = log_add(next_score.ns,
                                                prefix_score.score() + prob)
                        if next_score.v_ns < prefix_score.viterbi_score(
                        ) + prob:
                            next_score.v_ns = prefix_score.viterbi_score(
                            ) + prob
                            next_score.cur_token_prob = prob
                            next_score.times_ns = prefix_score.times().copy()
                            next_score.times_ns.append(t)
                        if context_graph and not next_score.has_context:
                            next_score.update_context(context_graph,
                                                      prefix_score, u)
                            next_score.has_context = True

            # 2.2 Second beam prune
            next_hyps = sorted(next_hyps.items(),
                               key=lambda x: x[1].total_score(),
                               reverse=True)
            cur_hyps = next_hyps[:beam_size]

        # We should backoff the context score/state when the context is
        # not fully matched at the last time step.
        if context_graph is not None:
            for i, hyp in enumerate(cur_hyps):
                context_score, new_context_state = context_graph.finalize(
                    hyp[1].context_state)
                cur_hyps[i][1].context_score = context_score
                cur_hyps[i][1].context_state = new_context_state

        nbest = [y[0] for y in cur_hyps]
        nbest_scores = [y[1].total_score() for y in cur_hyps]
        nbest_times = [y[1].times() for y in cur_hyps]
        best = nbest[0]
        best_score = nbest_scores[0]
        best_time = nbest_times[0]
        results.append(
            DecodeResult(tokens=best,
                         score=best_score,
                         times=best_time,
                         nbest=nbest,
                         nbest_scores=nbest_scores,
                         nbest_times=nbest_times))
    return results
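
# A minimal usage sketch of ctc_prefix_beam_search, for illustration only.
# The shapes, beam_size and blank_id are assumptions, not values taken from
# a real model; context biasing is optional and omitted here:
#
#     ctc_probs = torch.randn(1, 8, 10).log_softmax(dim=2)  # (B, T, vocab)
#     ctc_lens = torch.tensor([8], dtype=torch.long)
#     results = ctc_prefix_beam_search(ctc_probs, ctc_lens, beam_size=4)
#     print(results[0].nbest, results[0].nbest_scores)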


def attention_beam_search(
    model,
    encoder_out: torch.Tensor,
    encoder_mask: torch.Tensor,
    beam_size: int = 10,
    length_penalty: float = 0.0,
    infos: Dict[str, List[str]] = None,
) -> List[DecodeResult]:
    device = encoder_out.device
    batch_size = encoder_out.shape[0]
    # Let's assume B = batch_size and N = beam_size
    # 1. Encoder
    maxlen = encoder_out.size(1)
    encoder_dim = encoder_out.size(2)
    running_size = batch_size * beam_size
    if getattr(model, 'special_tokens', None) is not None \
            and "transcribe" in model.special_tokens:
        tasks, langs = infos["tasks"], infos["langs"]
        tasks = [t for t in tasks for _ in range(beam_size)]
        langs = [l for l in langs for _ in range(beam_size)]
        hyps = torch.ones([running_size, 0], dtype=torch.long,
                          device=device)  # (B*N, 0)
        hyps, _ = add_whisper_tokens(model.special_tokens,
                                     hyps,
                                     model.ignore_id,
                                     tasks=tasks,
                                     no_timestamp=True,
                                     langs=langs,
                                     use_prev=False)
    else:
        hyps = torch.ones([running_size, 1], dtype=torch.long,
                          device=device).fill_(model.sos)  # (B*N, 1)
    prefix_len = hyps.size(1)
    scores = torch.tensor([0.0] + [-float('inf')] * (beam_size - 1),
                          dtype=torch.float)
    scores = scores.to(device).repeat([batch_size]).unsqueeze(1)  # (B*N, 1)
    end_flag = torch.zeros_like(scores, dtype=torch.bool, device=device)
    cache = {
        'self_att_cache': {},
        'cross_att_cache': {},
    }
    if model.decoder.use_sdpa:
        encoder_mask = mask_to_bias(encoder_mask, encoder_out.dtype)
    if hasattr(model, 'decode_maxlen'):
        maxlen = model.decode_maxlen
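
    # A worked example of the flattened beam layout used in the loop below,
    # for illustration only (the numbers are assumptions, not real shapes):
    # with B = 2 and N = 3, rows 0..2 of hyps/scores are beams 0..2 of
    # utterance 0 and rows 3..5 are beams 0..2 of utterance 1. After the
    # per-step topk over the N*N candidates of each utterance,
    # offset_k_index lies in [0, N*N); offset_k_index // N recovers which of
    # the N source beams a candidate extends, and base_cache_index shifts
    # that beam index into the flat (B*N) row numbering, e.g. candidate 7 of
    # utterance 1 extends source beam 7 // 3 = 2, i.e. flat row 1 * 3 + 2 = 5.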

    # 2. Decoder forward step by step
    for i in range(prefix_len, maxlen + 1):
        # Stop if all batches and all beams produce eos
        if end_flag.sum() == running_size:
            break

        # 2.1 Forward decoder step
        hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
            running_size, 1, 1).to(device)  # (B*N, i, i)
        if model.decoder.use_sdpa:
            hyps_mask = mask_to_bias(hyps_mask, encoder_out.dtype)
        # logp: (B*N, vocab)
        logp = model.decoder.forward_one_step(encoder_out, encoder_mask, hyps,
                                              hyps_mask, cache)

        # 2.2 First beam prune: select topk best prob at current time
        top_k_logp, top_k_index = logp.topk(beam_size)  # (B*N, N)
        top_k_logp = mask_finished_scores(top_k_logp, end_flag)
        top_k_index = mask_finished_preds(top_k_index, end_flag, model.eos)

        # 2.3 Second beam prune: select topk score with history
        scores = scores + top_k_logp  # (B*N, N), broadcast add
        scores = scores.view(batch_size, beam_size * beam_size)  # (B, N*N)
        scores, offset_k_index = scores.topk(k=beam_size)  # (B, N)
        # Update cache to be consistent with the new topk scores / hyps
        cache_index = (offset_k_index // beam_size).view(-1)  # (B*N)
        base_cache_index = (torch.arange(batch_size, device=device).view(
            -1, 1).repeat([1, beam_size]) * beam_size).view(-1)  # (B*N)
        cache_index = base_cache_index + cache_index
        cache['self_att_cache'] = {
            i_layer: (torch.index_select(value[0], dim=0, index=cache_index),
                      torch.index_select(value[1], dim=0, index=cache_index))
            for (i_layer, value) in cache['self_att_cache'].items()
        }
        # NOTE(Mddct): we don't need to select the cross att cache here
        torch.cuda.empty_cache()
        scores = scores.view(-1, 1)  # (B*N, 1)

        # 2.4 Compute base index in top_k_index:
        # regard top_k_index as (B*N*N), regard offset_k_index as (B*N),
        # then find offset_k_index in top_k_index
        base_k_index = torch.arange(batch_size, device=device).view(
            -1, 1).repeat([1, beam_size])  # (B, N)
        base_k_index = base_k_index * beam_size * beam_size
        best_k_index = base_k_index.view(-1) + offset_k_index.view(
            -1)  # (B*N)

        # 2.5 Update best hyps
        best_k_pred = torch.index_select(top_k_index.view(-1),
                                         dim=-1,
                                         index=best_k_index)  # (B*N)
        best_hyps_index = best_k_index // beam_size
        last_best_k_hyps = torch.index_select(
            hyps, dim=0, index=best_hyps_index)  # (B*N, i)
        hyps = torch.cat((last_best_k_hyps, best_k_pred.view(-1, 1)),
                         dim=1)  # (B*N, i+1)

        # 2.6 Update end flag
        end_flag = torch.eq(hyps[:, -1], model.eos).view(-1, 1)
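
    # A short worked example of the length normalization applied below, for
    # illustration only (the numbers are assumptions): with
    # length_penalty = 0.5, a hypothesis with log score -7.0 and 10 non-eos
    # tokens is rescored to -7.0 / 10 ** 0.5 ~= -2.21, while
    # length_penalty = 0.0 divides by 10 ** 0 = 1 and leaves the score
    # unchanged; dividing a negative log score by a larger power of the
    # length makes it less negative, so a higher length_penalty favors
    # longer hypotheses.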

    # 3. Select best of best
    scores = scores.view(batch_size, beam_size)
    lengths = hyps.ne(model.eos).sum(dim=1).view(batch_size,
                                                 beam_size).float()
    scores = scores / lengths.pow(length_penalty)
    best_scores, best_index = scores.max(dim=-1)
    best_hyps_index = best_index + torch.arange(
        batch_size, dtype=torch.long, device=device) * beam_size
    best_hyps = torch.index_select(hyps, dim=0, index=best_hyps_index)
    best_hyps = best_hyps[:, prefix_len:]

    results = []
    for i in range(batch_size):
        hyp = best_hyps[i]
        hyp = hyp[hyp != model.eos]
        results.append(DecodeResult(hyp.tolist()))
    return results


def attention_rescoring(
    model,
    ctc_prefix_results: List[DecodeResult],
    encoder_outs: torch.Tensor,
    encoder_lens: torch.Tensor,
    ctc_weight: float = 0.0,
    reverse_weight: float = 0.0,
    infos: Dict[str, List[str]] = None,
) -> List[DecodeResult]:
    """
    Args:
        ctc_prefix_results(List[DecodeResult]): ctc prefix beam search results
    """
    sos, eos = model.sos_symbol(), model.eos_symbol()
    device = encoder_outs.device
    assert encoder_outs.shape[0] == len(ctc_prefix_results)
    batch_size = encoder_outs.shape[0]
    results = []
    for b in range(batch_size):
        encoder_out = encoder_outs[b, :encoder_lens[b], :].unsqueeze(0)
        hyps = ctc_prefix_results[b].nbest
        ctc_scores = ctc_prefix_results[b].nbest_scores
        hyps_pad = pad_sequence([
            torch.tensor(hyp, device=device, dtype=torch.long)
            for hyp in hyps
        ], True, model.ignore_id)  # (beam_size, max_hyps_len)
        hyps_lens = torch.tensor([len(hyp) for hyp in hyps],
                                 device=device,
                                 dtype=torch.long)  # (beam_size,)
        if getattr(model, 'special_tokens', None) is not None \
                and "transcribe" in model.special_tokens:
            prev_len = hyps_pad.size(1)
            hyps_pad, _ = add_whisper_tokens(
                model.special_tokens,
                hyps_pad,
                model.ignore_id,
                tasks=[infos["tasks"][b]] * len(hyps),
                no_timestamp=True,
                langs=[infos["langs"][b]] * len(hyps),
                use_prev=False)
            cur_len = hyps_pad.size(1)
            hyps_lens = hyps_lens + cur_len - prev_len
            prefix_len = 4
        else:
            hyps_pad, _ = add_sos_eos(hyps_pad, sos, eos, model.ignore_id)
            hyps_lens = hyps_lens + 1  # Add <sos> at the beginning
            prefix_len = 1
        decoder_out, r_decoder_out = model.forward_attention_decoder(
            hyps_pad, hyps_lens, encoder_out, reverse_weight)
        # Only use decoder score for rescoring
        best_score = -float('inf')
        best_index = 0
        confidences = []
        tokens_confidences = []
        for i, hyp in enumerate(hyps):
            score = 0.0
            tc = []  # token confidences
            for j, w in enumerate(hyp):
                s = decoder_out[i][j + (prefix_len - 1)][w]
                score += s
                tc.append(math.exp(s))
            score += decoder_out[i][len(hyp) + (prefix_len - 1)][eos]
            # add right-to-left decoder score
            if reverse_weight > 0 and r_decoder_out.dim() > 0:
                r_score = 0.0
                for j, w in enumerate(hyp):
                    s = r_decoder_out[i][len(hyp) - j - 1 +
                                         (prefix_len - 1)][w]
                    r_score += s
                    tc[j] = (tc[j] + math.exp(s)) / 2
                r_score += r_decoder_out[i][len(hyp) + (prefix_len - 1)][eos]
                score = score * (1 - reverse_weight) + \
                    r_score * reverse_weight
            confidences.append(math.exp(score / (len(hyp) + 1)))
            # add ctc score
            score += ctc_scores[i] * ctc_weight
            if score > best_score:
                best_score = score.item()
                best_index = i
            tokens_confidences.append(tc)
        results.append(
            DecodeResult(hyps[best_index],
                         best_score,
                         confidence=confidences[best_index],
                         times=ctc_prefix_results[b].nbest_times[best_index],
                         tokens_confidence=tokens_confidences[best_index]))
    return results
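

# A minimal, self-contained smoke test of the two CTC search functions above,
# for illustration only. The toy vocabulary size, lengths and beam size are
# assumptions, not values taken from a real recipe; the attention-based
# searches are omitted because they require a trained model.
if __name__ == '__main__':
    torch.manual_seed(0)
    vocab_size, batch, frames = 10, 2, 16
    # Fake frame-level log posteriors standing in for a CTC head's output
    probs = torch.randn(batch, frames, vocab_size).log_softmax(dim=2)
    lens = torch.tensor([frames, frames - 4], dtype=torch.long)
    for res in ctc_greedy_search(probs, lens, blank_id=0):
        print('greedy:', res.tokens)
    for res in ctc_prefix_beam_search(probs, lens, beam_size=4, blank_id=0):
        print('beam  :', res.tokens, 'nbest size:', len(res.nbest))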