# Copyright (c) 2023 Binbin Zhang ([email protected])
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from collections import defaultdict
from typing import List, Dict

import torch
from torch.nn.utils.rnn import pad_sequence

from wenet.utils.common import (add_sos_eos, log_add, add_whisper_tokens,
                                mask_to_bias)
from wenet.utils.ctc_utils import remove_duplicates_and_blank
from wenet.utils.mask import (make_pad_mask, mask_finished_preds,
                              mask_finished_scores, subsequent_mask)
from wenet.utils.context_graph import ContextGraph, ContextState


class DecodeResult:

    def __init__(self,
                 tokens: List[int],
                 score: float = 0.0,
                 confidence: float = 0.0,
                 tokens_confidence: List[float] = None,
                 times: List[int] = None,
                 nbest: List[List[int]] = None,
                 nbest_scores: List[float] = None,
                 nbest_times: List[List[int]] = None):
        """
        Args:
            tokens: decoded token list
            score: the total decoding score of this result
            confidence: the total confidence of this result, in [0, 1]
            tokens_confidence: confidence of each token
            times: timestamp (frame index) of each token
            nbest: nbest results
            nbest_scores: score of each nbest hypothesis
            nbest_times: timestamps of each nbest hypothesis
        """
        self.tokens = tokens
        self.score = score
        self.confidence = confidence
        self.tokens_confidence = tokens_confidence
        self.times = times
        self.nbest = nbest
        self.nbest_scores = nbest_scores
        self.nbest_times = nbest_times
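

# DecodeResult is a plain data container. A minimal sketch of reading one
# (assuming `results` was returned by one of the search functions below):
#
#     result = results[0]
#     result.tokens        # best token id sequence
#     result.confidence    # overall confidence in [0, 1], when computed
#     result.nbest         # full n-best list, when available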


class PrefixScore:
    """ For CTC prefix beam search """

    def __init__(self,
                 s: float = float('-inf'),
                 ns: float = float('-inf'),
                 v_s: float = float('-inf'),
                 v_ns: float = float('-inf'),
                 context_state: ContextState = None,
                 context_score: float = 0.0):
        self.s = s  # blank ending score
        self.ns = ns  # non-blank ending score
        self.v_s = v_s  # viterbi blank ending score
        self.v_ns = v_ns  # viterbi non-blank ending score
        self.cur_token_prob = float('-inf')  # prob of current token
        self.times_s = []  # times of viterbi blank path
        self.times_ns = []  # times of viterbi non-blank path
        self.context_state = context_state
        self.context_score = context_score
        self.has_context = False

    def score(self):
        return log_add(self.s, self.ns)

    def viterbi_score(self):
        return self.v_s if self.v_s > self.v_ns else self.v_ns

    def times(self):
        return self.times_s if self.v_s > self.v_ns else self.times_ns

    def total_score(self):
        return self.score() + self.context_score

    def copy_context(self, prefix_score):
        self.context_score = prefix_score.context_score
        self.context_state = prefix_score.context_state

    def update_context(self, context_graph, prefix_score, word_id):
        self.copy_context(prefix_score)
        (score, context_state) = context_graph.forward_one_step(
            prefix_score.context_state, word_id)
        self.context_score += score
        self.context_state = context_state
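

# A minimal sketch of how the two prefix scores combine (hypothetical numbers;
# log_add is the numerically stable log(exp(a) + exp(b)) used above):
#
#     p = PrefixScore(s=math.log(0.6), ns=math.log(0.3))
#     p.score()          # log(0.6 + 0.3): total log-prob of the prefix
#     p.viterbi_score()  # max of the two paths, used only for timestamps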


def ctc_greedy_search(ctc_probs: torch.Tensor,
                      ctc_lens: torch.Tensor,
                      blank_id: int = 0) -> List[DecodeResult]:
    batch_size = ctc_probs.shape[0]
    maxlen = ctc_probs.size(1)
    topk_prob, topk_index = ctc_probs.topk(1, dim=2)  # (B, maxlen, 1)
    topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
    mask = make_pad_mask(ctc_lens, maxlen)  # (B, maxlen)
    topk_index = topk_index.masked_fill_(mask, blank_id)  # (B, maxlen)
    hyps = [hyp.tolist() for hyp in topk_index]
    scores = topk_prob.max(1)
    results = []
    for hyp in hyps:
        r = DecodeResult(remove_duplicates_and_blank(hyp, blank_id))
        results.append(r)
    return results
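

# Example use of ctc_greedy_search, a minimal self-contained sketch with dummy
# log-probabilities (a real model's CTC head would supply ctc_probs/ctc_lens):
#
#     B, T, V = 2, 50, 100
#     ctc_probs = torch.randn(B, T, V).log_softmax(dim=-1)  # (B, T, vocab)
#     ctc_lens = torch.tensor([50, 35])                      # valid frames
#     results = ctc_greedy_search(ctc_probs, ctc_lens, blank_id=0)
#     results[0].tokens  # best token ids for the first utterance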


def ctc_prefix_beam_search(
    ctc_probs: torch.Tensor,
    ctc_lens: torch.Tensor,
    beam_size: int,
    context_graph: ContextGraph = None,
    blank_id: int = 0,
) -> List[DecodeResult]:
    """
    Returns:
        List[DecodeResult]: best result (plus the n-best list) per utterance
    """
    batch_size = ctc_probs.shape[0]
    results = []
    # CTC prefix beam search cannot be parallelized, so search one by one
    for i in range(batch_size):
        ctc_prob = ctc_probs[i]
        num_t = ctc_lens[i]
        # 1. Initialize with the empty prefix, which ends with blank
        cur_hyps = [(tuple(),
                     PrefixScore(s=0.0,
                                 ns=-float('inf'),
                                 v_s=0.0,
                                 v_ns=0.0,
                                 context_state=None if context_graph is None
                                 else context_graph.root,
                                 context_score=0.0))]
        # 2. CTC beam search step by step
        for t in range(0, num_t):
            logp = ctc_prob[t]  # (vocab_size,)
            # key: prefix, value: PrefixScore
            next_hyps = defaultdict(lambda: PrefixScore())
            # 2.1 First beam prune: select topk best
            top_k_logp, top_k_index = logp.topk(beam_size)  # (beam_size,)
            for u in top_k_index:
                u = u.item()
                prob = logp[u].item()
                for prefix, prefix_score in cur_hyps:
                    last = prefix[-1] if len(prefix) > 0 else None
                    if u == blank_id:  # blank
                        next_score = next_hyps[prefix]
                        next_score.s = log_add(next_score.s,
                                               prefix_score.score() + prob)
                        next_score.v_s = prefix_score.viterbi_score() + prob
                        next_score.times_s = prefix_score.times().copy()
                        # prefix not changed, copy the context from prefix
                        if context_graph and not next_score.has_context:
                            next_score.copy_context(prefix_score)
                            next_score.has_context = True
                    elif u == last:
                        # Update *uu -> *u;
                        next_score1 = next_hyps[prefix]
                        next_score1.ns = log_add(next_score1.ns,
                                                 prefix_score.ns + prob)
                        if next_score1.v_ns < prefix_score.v_ns + prob:
                            next_score1.v_ns = prefix_score.v_ns + prob
                            if next_score1.cur_token_prob < prob:
                                next_score1.cur_token_prob = prob
                                next_score1.times_ns = \
                                    prefix_score.times_ns.copy()
                                next_score1.times_ns[-1] = t
                        if context_graph and not next_score1.has_context:
                            next_score1.copy_context(prefix_score)
                            next_score1.has_context = True
                        # Update *u-u -> *uu, - is for blank
                        n_prefix = prefix + (u, )
                        next_score2 = next_hyps[n_prefix]
                        next_score2.ns = log_add(next_score2.ns,
                                                 prefix_score.s + prob)
                        if next_score2.v_ns < prefix_score.v_s + prob:
                            next_score2.v_ns = prefix_score.v_s + prob
                            next_score2.cur_token_prob = prob
                            next_score2.times_ns = prefix_score.times_s.copy()
                            next_score2.times_ns.append(t)
                        if context_graph and not next_score2.has_context:
                            next_score2.update_context(context_graph,
                                                       prefix_score, u)
                            next_score2.has_context = True
                    else:
                        n_prefix = prefix + (u, )
                        next_score = next_hyps[n_prefix]
                        next_score.ns = log_add(next_score.ns,
                                                prefix_score.score() + prob)
                        if next_score.v_ns < prefix_score.viterbi_score() \
                                + prob:
                            next_score.v_ns = prefix_score.viterbi_score() \
                                + prob
                            next_score.cur_token_prob = prob
                            next_score.times_ns = prefix_score.times().copy()
                            next_score.times_ns.append(t)
                        if context_graph and not next_score.has_context:
                            next_score.update_context(context_graph,
                                                      prefix_score, u)
                            next_score.has_context = True

            # 2.2 Second beam prune
            next_hyps = sorted(next_hyps.items(),
                               key=lambda x: x[1].total_score(),
                               reverse=True)
            cur_hyps = next_hyps[:beam_size]

        # We should backoff the context score/state when the context is
        # not fully matched at the last time.
        if context_graph is not None:
            for i, hyp in enumerate(cur_hyps):
                context_score, new_context_state = context_graph.finalize(
                    hyp[1].context_state)
                cur_hyps[i][1].context_score = context_score
                cur_hyps[i][1].context_state = new_context_state

        nbest = [y[0] for y in cur_hyps]
        nbest_scores = [y[1].total_score() for y in cur_hyps]
        nbest_times = [y[1].times() for y in cur_hyps]
        best = nbest[0]
        best_score = nbest_scores[0]
        best_time = nbest_times[0]
        results.append(
            DecodeResult(tokens=best,
                         score=best_score,
                         times=best_time,
                         nbest=nbest,
                         nbest_scores=nbest_scores,
                         nbest_times=nbest_times))
    return results
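

# Example use of ctc_prefix_beam_search (illustrative; in wenet's two-pass
# recipe ctc_probs typically comes from model.ctc.log_softmax(encoder_out) and
# ctc_lens from the encoder mask, but any (B, T, vocab) log-probs work):
#
#     results = ctc_prefix_beam_search(ctc_probs, ctc_lens, beam_size=10)
#     results[0].nbest   # 10-best prefixes for the first utterance
#     results[0].times   # peak frame index of each token in the best prefix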


def attention_beam_search(
    model,
    encoder_out: torch.Tensor,
    encoder_mask: torch.Tensor,
    beam_size: int = 10,
    length_penalty: float = 0.0,
    infos: Dict[str, List[str]] = None,
) -> List[DecodeResult]:
    device = encoder_out.device
    batch_size = encoder_out.shape[0]
    # Let's assume B = batch_size and N = beam_size
    # 1. Encoder
    maxlen = encoder_out.size(1)
    encoder_dim = encoder_out.size(2)
    running_size = batch_size * beam_size
    if getattr(model, 'special_tokens', None) is not None \
            and "transcribe" in model.special_tokens:
        tasks, langs = infos["tasks"], infos["langs"]
        tasks = [t for t in tasks for _ in range(beam_size)]
        langs = [l for l in langs for _ in range(beam_size)]
        hyps = torch.ones([running_size, 0], dtype=torch.long,
                          device=device)  # (B*N, 0)
        hyps, _ = add_whisper_tokens(model.special_tokens,
                                     hyps,
                                     model.ignore_id,
                                     tasks=tasks,
                                     no_timestamp=True,
                                     langs=langs,
                                     use_prev=False)
    else:
        hyps = torch.ones([running_size, 1], dtype=torch.long,
                          device=device).fill_(model.sos)  # (B*N, 1)
    prefix_len = hyps.size(1)
    scores = torch.tensor([0.0] + [-float('inf')] * (beam_size - 1),
                          dtype=torch.float)
    scores = scores.to(device).repeat([batch_size
                                       ]).unsqueeze(1).to(device)  # (B*N, 1)
    end_flag = torch.zeros_like(scores, dtype=torch.bool, device=device)
    cache = {
        'self_att_cache': {},
        'cross_att_cache': {},
    }
    if model.decoder.use_sdpa:
        encoder_mask = mask_to_bias(encoder_mask, encoder_out.dtype)
    if hasattr(model, 'decode_maxlen'):
        maxlen = model.decode_maxlen

    # 2. Decoder forward step by step
    for i in range(prefix_len, maxlen + 1):
        # Stop if all batch and all beam produce eos
        if end_flag.sum() == running_size:
            break
        # 2.1 Forward decoder step
        hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
            running_size, 1, 1).to(device)  # (B*N, i, i)
        if model.decoder.use_sdpa:
            hyps_mask = mask_to_bias(hyps_mask, encoder_out.dtype)
        # logp: (B*N, vocab)
        logp = model.decoder.forward_one_step(encoder_out, encoder_mask, hyps,
                                              hyps_mask, cache)
        # 2.2 First beam prune: select topk best prob at current time
        top_k_logp, top_k_index = logp.topk(beam_size)  # (B*N, N)
        top_k_logp = mask_finished_scores(top_k_logp, end_flag)
        top_k_index = mask_finished_preds(top_k_index, end_flag, model.eos)
        # 2.3 Second beam prune: select topk score with history
        scores = scores + top_k_logp  # (B*N, N), broadcast add
        scores = scores.view(batch_size, beam_size * beam_size)  # (B, N*N)
        scores, offset_k_index = scores.topk(k=beam_size)  # (B, N)
        # Update cache to be consistent with new topk scores / hyps
        cache_index = (offset_k_index // beam_size).view(-1)  # (B*N)
        base_cache_index = (torch.arange(batch_size, device=device).view(
            -1, 1).repeat([1, beam_size]) * beam_size).view(-1)  # (B*N)
        cache_index = base_cache_index + cache_index
        cache['self_att_cache'] = {
            i_layer: (torch.index_select(value[0], dim=0, index=cache_index),
                      torch.index_select(value[1], dim=0, index=cache_index))
            for (i_layer, value) in cache['self_att_cache'].items()
        }
        # NOTE(Mddct): we don't need to select the cross-attention cache here
        torch.cuda.empty_cache()
        scores = scores.view(-1, 1)  # (B*N, 1)
        # 2.4 Compute base index in top_k_index,
        # regard top_k_index as (B*N*N), regard offset_k_index as (B*N),
        # then find offset_k_index in top_k_index
        base_k_index = torch.arange(batch_size, device=device).view(
            -1, 1).repeat([1, beam_size])  # (B, N)
        base_k_index = base_k_index * beam_size * beam_size
        best_k_index = base_k_index.view(-1) + offset_k_index.view(-1)  # (B*N)
        # 2.5 Update best hyps
        best_k_pred = torch.index_select(top_k_index.view(-1),
                                         dim=-1,
                                         index=best_k_index)  # (B*N)
        best_hyps_index = best_k_index // beam_size
        last_best_k_hyps = torch.index_select(
            hyps, dim=0, index=best_hyps_index)  # (B*N, i)
        hyps = torch.cat((last_best_k_hyps, best_k_pred.view(-1, 1)),
                         dim=1)  # (B*N, i+1)
        # 2.6 Update end flag
        end_flag = torch.eq(hyps[:, -1], model.eos).view(-1, 1)

    # 3. Select best of best
    scores = scores.view(batch_size, beam_size)
    lengths = hyps.ne(model.eos).sum(dim=1).view(batch_size, beam_size).float()
    scores = scores / lengths.pow(length_penalty)
    best_scores, best_index = scores.max(dim=-1)
    best_hyps_index = best_index + torch.arange(
        batch_size, dtype=torch.long, device=device) * beam_size
    best_hyps = torch.index_select(hyps, dim=0, index=best_hyps_index)
    best_hyps = best_hyps[:, prefix_len:]

    results = []
    for i in range(batch_size):
        hyp = best_hyps[i]
        hyp = hyp[hyp != model.eos]
        results.append(DecodeResult(hyp.tolist()))
    return results
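

# Example use of attention_beam_search (illustrative; assumes a wenet ASRModel
# whose encoder returns (encoder_out, encoder_mask) and which defines
# sos/eos/ignore_id and an autoregressive decoder with forward_one_step):
#
#     encoder_out, encoder_mask = model.encoder(feats, feats_lens)
#     results = attention_beam_search(model, encoder_out, encoder_mask,
#                                     beam_size=10, length_penalty=0.0)
#     results[0].tokens  # best hypothesis for the first utterance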


def attention_rescoring(
    model,
    ctc_prefix_results: List[DecodeResult],
    encoder_outs: torch.Tensor,
    encoder_lens: torch.Tensor,
    ctc_weight: float = 0.0,
    reverse_weight: float = 0.0,
    infos: Dict[str, List[str]] = None,
) -> List[DecodeResult]:
    """
    Args:
        ctc_prefix_results (List[DecodeResult]): ctc prefix beam search results
    """
    sos, eos = model.sos_symbol(), model.eos_symbol()
    device = encoder_outs.device
    assert encoder_outs.shape[0] == len(ctc_prefix_results)
    batch_size = encoder_outs.shape[0]
    results = []
    for b in range(batch_size):
        encoder_out = encoder_outs[b, :encoder_lens[b], :].unsqueeze(0)
        hyps = ctc_prefix_results[b].nbest
        ctc_scores = ctc_prefix_results[b].nbest_scores
        hyps_pad = pad_sequence([
            torch.tensor(hyp, device=device, dtype=torch.long) for hyp in hyps
        ], True, model.ignore_id)  # (beam_size, max_hyps_len)
        hyps_lens = torch.tensor([len(hyp) for hyp in hyps],
                                 device=device,
                                 dtype=torch.long)  # (beam_size,)
        if getattr(model, 'special_tokens', None) is not None \
                and "transcribe" in model.special_tokens:
            prev_len = hyps_pad.size(1)
            hyps_pad, _ = add_whisper_tokens(
                model.special_tokens,
                hyps_pad,
                model.ignore_id,
                tasks=[infos["tasks"][b]] * len(hyps),
                no_timestamp=True,
                langs=[infos["langs"][b]] * len(hyps),
                use_prev=False)
            cur_len = hyps_pad.size(1)
            hyps_lens = hyps_lens + cur_len - prev_len
            prefix_len = 4
        else:
            hyps_pad, _ = add_sos_eos(hyps_pad, sos, eos, model.ignore_id)
            hyps_lens = hyps_lens + 1  # Add <sos> at the beginning
            prefix_len = 1
        decoder_out, r_decoder_out = model.forward_attention_decoder(
            hyps_pad, hyps_lens, encoder_out, reverse_weight)
        # Only use decoder score for rescoring
        best_score = -float('inf')
        best_index = 0
        confidences = []
        tokens_confidences = []
        for i, hyp in enumerate(hyps):
            score = 0.0
            tc = []  # token confidences
            for j, w in enumerate(hyp):
                s = decoder_out[i][j + (prefix_len - 1)][w]
                score += s
                tc.append(math.exp(s))
            score += decoder_out[i][len(hyp) + (prefix_len - 1)][eos]
            # add right to left decoder score
            if reverse_weight > 0 and r_decoder_out.dim() > 0:
                r_score = 0.0
                for j, w in enumerate(hyp):
                    s = r_decoder_out[i][len(hyp) - j - 1 +
                                         (prefix_len - 1)][w]
                    r_score += s
                    tc[j] = (tc[j] + math.exp(s)) / 2
                r_score += r_decoder_out[i][len(hyp) + (prefix_len - 1)][eos]
                score = score * (1 - reverse_weight) + r_score * reverse_weight
            confidences.append(math.exp(score / (len(hyp) + 1)))
            # add ctc score
            score += ctc_scores[i] * ctc_weight
            if score > best_score:
                best_score = score.item()
                best_index = i
            tokens_confidences.append(tc)
        results.append(
            DecodeResult(hyps[best_index],
                         best_score,
                         confidence=confidences[best_index],
                         times=ctc_prefix_results[b].nbest_times[best_index],
                         tokens_confidence=tokens_confidences[best_index]))
    return results
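

# Example of the two-pass pipeline (illustrative; assumes the same model
# produced the encoder outputs and the CTC posteriors, as in wenet's
# attention-rescoring decode mode):
#
#     ctc_results = ctc_prefix_beam_search(ctc_probs, ctc_lens, beam_size=10)
#     encoder_lens = encoder_mask.squeeze(1).sum(1)
#     rescored = attention_rescoring(model, ctc_results, encoder_out,
#                                    encoder_lens, ctc_weight=0.3)
#     rescored[0].tokens  # best hypothesis after reranking the n-best list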