""" |
|
Flashlight decoders for CTC acoustic models: Viterbi, KenLM beam search, and fairseq neural LM fusion.
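
A minimal usage sketch (names and values are placeholders; the decoders only
assume an argparse-style namespace exposing the attributes read below)::

    args = argparse.Namespace(
        nbest=1, beam=50, beam_threshold=25.0, lm_weight=2.0, word_score=1.0,
        unk_weight=float("-inf"), sil_weight=0.0, unit_lm=False,
        lexicon="lexicon.txt", kenlm_model="lm.bin",
    )
    decoder = W2lKenLMDecoder(args, tgt_dict)
    hypos = decoder.generate(models, sample)  # one n-best list per batch element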
|
""" |
|
|
|
import gc |
|
import itertools as it |
|
import os.path as osp |
|
from typing import List |
|
import warnings |
|
from collections import deque, namedtuple |
|
|
|
import numpy as np |
|
import torch |
|
from examples.speech_recognition.data.replabels import unpack_replabels |
|
from fairseq import tasks |
|
from fairseq.utils import apply_to_sample |
|
from omegaconf import open_dict |
|
from fairseq.dataclass.utils import convert_namespace_to_omegaconf |
|
|
|
|
|
try: |
|
from flashlight.lib.text.dictionary import create_word_dict, load_words |
|
from flashlight.lib.sequence.criterion import CpuViterbiPath, get_data_ptr_as_bytes |
|
from flashlight.lib.text.decoder import ( |
|
CriterionType, |
|
LexiconDecoderOptions, |
|
KenLM, |
|
LM, |
|
LMState, |
|
SmearingMode, |
|
Trie, |
|
LexiconDecoder, |
|
) |
|
except Exception:
|
warnings.warn( |
|
"flashlight python bindings are required to use this functionality. Please install from https://github.com/facebookresearch/flashlight/tree/master/bindings/python" |
|
) |
|
LM = object |
|
LMState = object |
|
|
|
|
|
class W2lDecoder(object): |
|
def __init__(self, args, tgt_dict): |
|
self.tgt_dict = tgt_dict |
|
self.vocab_size = len(tgt_dict) |
|
self.nbest = args.nbest |
|
|
|
|
|
self.criterion_type = CriterionType.CTC |
|
self.blank = ( |
|
tgt_dict.index("<ctc_blank>") |
|
if "<ctc_blank>" in tgt_dict.indices |
|
else tgt_dict.bos() |
|
) |
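        # The CTC blank falls back to bos() when no dedicated <ctc_blank> symbol
        # exists; the silence token chosen below marks word boundaries for the
        # lexicon-based decoders.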
|
if "<sep>" in tgt_dict.indices: |
|
self.silence = tgt_dict.index("<sep>") |
|
elif "|" in tgt_dict.indices: |
|
self.silence = tgt_dict.index("|") |
|
else: |
|
self.silence = tgt_dict.eos() |
|
self.asg_transitions = None |
|
|
|
def generate(self, models, sample, **unused): |
|
"""Generate a batch of inferences.""" |
|
|
|
|
|
encoder_input = { |
|
k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens" |
|
} |
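        # prev_output_tokens is only needed for autoregressive (teacher-forced)
        # decoders; CTC emission extraction uses the encoder inputs alone.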
|
emissions = self.get_emissions(models, encoder_input) |
|
return self.decode(emissions) |
|
|
|
def get_emissions(self, models, encoder_input): |
|
"""Run encoder and normalize emissions""" |
|
model = models[0] |
|
encoder_out = model(**encoder_input) |
|
if hasattr(model, "get_logits"): |
|
emissions = model.get_logits(encoder_out) |
|
else: |
|
emissions = model.get_normalized_probs(encoder_out, log_probs=True) |
|
return emissions.transpose(0, 1).float().cpu().contiguous() |
|
|
|
def get_tokens(self, idxs): |
|
"""Normalize tokens by handling CTC blank, ASG replabels, etc.""" |
|
idxs = (g[0] for g in it.groupby(idxs)) |
|
idxs = filter(lambda x: x != self.blank, idxs) |
|
return torch.LongTensor(list(idxs)) |
|
|
|
|
|
class W2lViterbiDecoder(W2lDecoder): |
|
def __init__(self, args, tgt_dict): |
|
super().__init__(args, tgt_dict) |
|
|
|
def decode(self, emissions): |
|
B, T, N = emissions.size() |
|
|
if self.asg_transitions is None: |
|
transitions = torch.FloatTensor(N, N).zero_() |
|
else: |
|
transitions = torch.FloatTensor(self.asg_transitions).view(N, N) |
|
viterbi_path = torch.IntTensor(B, T) |
|
workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N)) |
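        # With an all-zero transition matrix (the CTC case) the Viterbi path reduces
        # to a framewise argmax; CpuViterbiPath.compute fills viterbi_path in place.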
|
CpuViterbiPath.compute( |
|
B, |
|
T, |
|
N, |
|
get_data_ptr_as_bytes(emissions), |
|
get_data_ptr_as_bytes(transitions), |
|
get_data_ptr_as_bytes(viterbi_path), |
|
get_data_ptr_as_bytes(workspace), |
|
) |
|
return [ |
|
[{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}] |
|
for b in range(B) |
|
] |
|
|
|
|
|
class W2lKenLMDecoder(W2lDecoder): |
|
def __init__(self, args, tgt_dict): |
|
super().__init__(args, tgt_dict) |
|
|
|
self.unit_lm = getattr(args, "unit_lm", False) |
|
|
|
if args.lexicon: |
|
self.lexicon = load_words(args.lexicon) |
|
self.word_dict = create_word_dict(self.lexicon) |
|
self.unk_word = self.word_dict.get_index("<unk>") |
|
|
|
self.lm = KenLM(args.kenlm_model, self.word_dict) |
|
self.trie = Trie(self.vocab_size, self.silence) |
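            # Build a prefix trie over lexicon spellings: each word is scored once
            # from the LM start state, and MAX smearing propagates the best word
            # score to prefix nodes so partial words can be scored during the search.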
|
|
|
start_state = self.lm.start(False) |
|
for i, (word, spellings) in enumerate(self.lexicon.items()): |
|
word_idx = self.word_dict.get_index(word) |
|
_, score = self.lm.score(start_state, word_idx) |
|
for spelling in spellings: |
|
spelling_idxs = [tgt_dict.index(token) for token in spelling] |
|
assert ( |
|
tgt_dict.unk() not in spelling_idxs |
|
), f"{spelling} {spelling_idxs}" |
|
self.trie.insert(spelling_idxs, word_idx, score) |
|
self.trie.smear(SmearingMode.MAX) |
|
|
|
self.decoder_opts = LexiconDecoderOptions( |
|
beam_size=args.beam, |
|
beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))), |
|
beam_threshold=args.beam_threshold, |
|
lm_weight=args.lm_weight, |
|
word_score=args.word_score, |
|
unk_score=args.unk_weight, |
|
sil_score=args.sil_weight, |
|
log_add=False, |
|
criterion_type=self.criterion_type, |
|
) |
|
|
|
            if self.asg_transitions is None:
                # CTC has no ASG transition matrix; pass an empty list to the decoder.
                self.asg_transitions = []
|
|
|
self.decoder = LexiconDecoder( |
|
self.decoder_opts, |
|
self.trie, |
|
self.lm, |
|
self.silence, |
|
self.blank, |
|
self.unk_word, |
|
self.asg_transitions, |
|
self.unit_lm, |
|
) |
|
else: |
|
assert args.unit_lm, "lexicon free decoding can only be done with a unit language model" |
|
from flashlight.lib.text.decoder import LexiconFreeDecoder, LexiconFreeDecoderOptions |
|
|
|
d = {w: [[w]] for w in tgt_dict.symbols} |
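            # Lexicon-free: every target unit acts as its own "word" so the KenLM
            # vocabulary lines up with tgt_dict.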
|
self.word_dict = create_word_dict(d) |
|
self.lm = KenLM(args.kenlm_model, self.word_dict) |
|
self.decoder_opts = LexiconFreeDecoderOptions( |
|
beam_size=args.beam, |
|
beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))), |
|
beam_threshold=args.beam_threshold, |
|
lm_weight=args.lm_weight, |
|
sil_score=args.sil_weight, |
|
log_add=False, |
|
criterion_type=self.criterion_type, |
|
) |
|
self.decoder = LexiconFreeDecoder( |
|
self.decoder_opts, self.lm, self.silence, self.blank, [] |
|
) |
|
|
|
def get_timesteps(self, token_idxs: List[int]) -> List[int]: |
|
"""Returns frame numbers corresponding to every non-blank token. |
|
|
|
Parameters |
|
---------- |
|
token_idxs : List[int] |
|
IDs of decoded tokens. |
|
|
|
Returns |
|
------- |
|
List[int] |
|
Frame numbers corresponding to every non-blank token. |
|
""" |
|
timesteps = [] |
|
for i, token_idx in enumerate(token_idxs): |
|
if token_idx == self.blank: |
|
continue |
|
if i == 0 or token_idx != token_idxs[i-1]: |
|
timesteps.append(i) |
|
return timesteps |
|
|
|
def decode(self, emissions): |
|
B, T, N = emissions.size() |
|
hypos = [] |
|
for b in range(B): |
|
            # emissions is float32 (4 bytes per element) and batch-major, so this is
            # the byte offset of sample b's (T, N) emission block.
            emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
|
results = self.decoder.decode(emissions_ptr, T, N) |
|
|
|
nbest_results = results[: self.nbest] |
|
hypos.append( |
|
[ |
|
{ |
|
"tokens": self.get_tokens(result.tokens), |
|
"score": result.score, |
|
"timesteps": self.get_timesteps(result.tokens), |
|
"words": [ |
|
self.word_dict.get_entry(x) for x in result.words if x >= 0 |
|
], |
|
} |
|
for result in nbest_results |
|
] |
|
) |
|
return hypos |
|
|
|
|
|
FairseqLMState = namedtuple("FairseqLMState", ["prefix", "incremental_state", "probs"]) |
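# prefix: token ids fed to the LM so far; incremental_state: cached decoder state
# (or None); probs: log-probabilities over the next token given the prefix.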
|
|
|
|
|
class FairseqLM(LM): |
|
def __init__(self, dictionary, model): |
|
LM.__init__(self) |
|
self.dictionary = dictionary |
|
self.model = model |
|
self.unk = self.dictionary.unk() |
|
|
|
self.save_incremental = False |
|
self.max_cache = 20_000 |
|
|
|
model.cuda() |
|
model.eval() |
|
model.make_generation_fast_() |
|
|
|
self.states = {} |
|
self.stateq = deque() |
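        # Cache of LMState -> FairseqLMState, trimmed FIFO in score() so that at
        # most max_cache entries keep probs / incremental_state in memory.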
|
|
|
def start(self, start_with_nothing): |
|
state = LMState() |
|
prefix = torch.LongTensor([[self.dictionary.eos()]]) |
|
incremental_state = {} if self.save_incremental else None |
|
with torch.no_grad(): |
|
res = self.model(prefix.cuda(), incremental_state=incremental_state) |
|
probs = self.model.get_normalized_probs(res, log_probs=True, sample=None) |
|
|
|
if incremental_state is not None: |
|
incremental_state = apply_to_sample(lambda x: x.cpu(), incremental_state) |
|
self.states[state] = FairseqLMState( |
|
prefix.numpy(), incremental_state, probs[0, -1].cpu().numpy() |
|
) |
|
self.stateq.append(state) |
|
|
|
return state |
|
|
|
def score(self, state: LMState, token_index: int, no_cache: bool = False): |
|
""" |
|
Evaluate language model based on the current lm state and new word |
|
Parameters: |
|
----------- |
|
state: current lm state |
|
token_index: index of the word |
|
(can be lexicon index then you should store inside LM the |
|
mapping between indices of lexicon and lm, or lm index of a word) |
|
|
|
Returns: |
|
-------- |
|
(LMState, float): pair of (new state, score for the current word) |
|
""" |
|
curr_state = self.states[state] |
|
|
|
def trim_cache(targ_size): |
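            # Evict the oldest cached states: keep only the prefix so scores can be
            # recomputed on demand, dropping probs and incremental_state.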
|
while len(self.stateq) > targ_size: |
|
rem_k = self.stateq.popleft() |
|
rem_st = self.states[rem_k] |
|
rem_st = FairseqLMState(rem_st.prefix, None, None) |
|
self.states[rem_k] = rem_st |
|
|
|
if curr_state.probs is None: |
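            # This state was evicted (or never fully scored); re-run the LM on the
            # stored prefix to recover the next-token distribution.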
|
new_incremental_state = ( |
|
curr_state.incremental_state.copy() |
|
if curr_state.incremental_state is not None |
|
else None |
|
) |
|
with torch.no_grad(): |
|
if new_incremental_state is not None: |
|
new_incremental_state = apply_to_sample( |
|
lambda x: x.cuda(), new_incremental_state |
|
) |
|
elif self.save_incremental: |
|
new_incremental_state = {} |
|
|
|
res = self.model( |
|
torch.from_numpy(curr_state.prefix).cuda(), |
|
incremental_state=new_incremental_state, |
|
) |
|
probs = self.model.get_normalized_probs( |
|
res, log_probs=True, sample=None |
|
) |
|
|
|
if new_incremental_state is not None: |
|
new_incremental_state = apply_to_sample( |
|
lambda x: x.cpu(), new_incremental_state |
|
) |
|
|
|
curr_state = FairseqLMState( |
|
curr_state.prefix, new_incremental_state, probs[0, -1].cpu().numpy() |
|
) |
|
|
|
if not no_cache: |
|
self.states[state] = curr_state |
|
self.stateq.append(state) |
|
|
|
score = curr_state.probs[token_index].item() |
|
|
|
trim_cache(self.max_cache) |
|
|
|
outstate = state.child(token_index) |
|
if outstate not in self.states and not no_cache: |
|
prefix = np.concatenate( |
|
[curr_state.prefix, torch.LongTensor([[token_index]])], -1 |
|
) |
|
incr_state = curr_state.incremental_state |
|
|
|
self.states[outstate] = FairseqLMState(prefix, incr_state, None) |
|
|
|
if token_index == self.unk: |
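            # Disallow <unk> as an LM continuation.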
|
score = float("-inf") |
|
|
|
return outstate, score |
|
|
|
def finish(self, state: LMState): |
|
""" |
|
Evaluate eos for language model based on the current lm state |
|
|
|
Returns: |
|
-------- |
|
(LMState, float): pair of (new state, score for the current word) |
|
""" |
|
return self.score(state, self.dictionary.eos()) |
|
|
|
def empty_cache(self): |
|
self.states = {} |
|
self.stateq = deque() |
|
gc.collect() |
|
|
|
|
|
class W2lFairseqLMDecoder(W2lDecoder): |
|
def __init__(self, args, tgt_dict): |
|
super().__init__(args, tgt_dict) |
|
|
|
self.unit_lm = getattr(args, "unit_lm", False) |
|
|
|
self.lexicon = load_words(args.lexicon) if args.lexicon else None |
|
self.idx_to_wrd = {} |
|
|
|
        # Despite the argument name, args.kenlm_model points to a fairseq LM checkpoint here.
        checkpoint = torch.load(args.kenlm_model, map_location="cpu")
|
|
|
if "cfg" in checkpoint and checkpoint["cfg"] is not None: |
|
lm_args = checkpoint["cfg"] |
|
else: |
|
lm_args = convert_namespace_to_omegaconf(checkpoint["args"]) |
|
|
|
with open_dict(lm_args.task): |
|
lm_args.task.data = osp.dirname(args.kenlm_model) |
|
|
|
task = tasks.setup_task(lm_args.task) |
|
model = task.build_model(lm_args.model) |
|
model.load_state_dict(checkpoint["model"], strict=False) |
|
|
|
self.trie = Trie(self.vocab_size, self.silence) |
|
|
|
self.word_dict = task.dictionary |
|
self.unk_word = self.word_dict.unk() |
|
self.lm = FairseqLM(self.word_dict, model) |
|
|
|
if self.lexicon: |
|
start_state = self.lm.start(False) |
|
for i, (word, spellings) in enumerate(self.lexicon.items()): |
|
if self.unit_lm: |
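                    # Unit LM: index words by their position in the lexicon (mapped
                    # back through idx_to_wrd at decode time) and skip LM scoring here.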
|
word_idx = i |
|
self.idx_to_wrd[i] = word |
|
score = 0 |
|
else: |
|
word_idx = self.word_dict.index(word) |
|
_, score = self.lm.score(start_state, word_idx, no_cache=True) |
|
|
|
for spelling in spellings: |
|
spelling_idxs = [tgt_dict.index(token) for token in spelling] |
|
assert ( |
|
tgt_dict.unk() not in spelling_idxs |
|
), f"{spelling} {spelling_idxs}" |
|
self.trie.insert(spelling_idxs, word_idx, score) |
|
self.trie.smear(SmearingMode.MAX) |
|
|
|
self.decoder_opts = LexiconDecoderOptions( |
|
beam_size=args.beam, |
|
beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))), |
|
beam_threshold=args.beam_threshold, |
|
lm_weight=args.lm_weight, |
|
word_score=args.word_score, |
|
unk_score=args.unk_weight, |
|
sil_score=args.sil_weight, |
|
log_add=False, |
|
criterion_type=self.criterion_type, |
|
) |
|
|
|
self.decoder = LexiconDecoder( |
|
self.decoder_opts, |
|
self.trie, |
|
self.lm, |
|
self.silence, |
|
self.blank, |
|
self.unk_word, |
|
[], |
|
self.unit_lm, |
|
) |
|
else: |
|
assert args.unit_lm, "lexicon free decoding can only be done with a unit language model" |
|
from flashlight.lib.text.decoder import LexiconFreeDecoder, LexiconFreeDecoderOptions |
|
|
|
d = {w: [[w]] for w in tgt_dict.symbols} |
|
self.word_dict = create_word_dict(d) |
|
self.lm = KenLM(args.kenlm_model, self.word_dict) |
|
self.decoder_opts = LexiconFreeDecoderOptions( |
|
beam_size=args.beam, |
|
beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))), |
|
beam_threshold=args.beam_threshold, |
|
lm_weight=args.lm_weight, |
|
sil_score=args.sil_weight, |
|
log_add=False, |
|
criterion_type=self.criterion_type, |
|
) |
|
self.decoder = LexiconFreeDecoder( |
|
self.decoder_opts, self.lm, self.silence, self.blank, [] |
|
) |
|
|
|
def decode(self, emissions): |
|
B, T, N = emissions.size() |
|
hypos = [] |
|
|
|
def idx_to_word(idx): |
|
if self.unit_lm: |
|
return self.idx_to_wrd[idx] |
|
else: |
|
return self.word_dict[idx] |
|
|
|
def make_hypo(result): |
|
hypo = {"tokens": self.get_tokens(result.tokens), "score": result.score} |
|
if self.lexicon: |
|
hypo["words"] = [idx_to_word(x) for x in result.words if x >= 0] |
|
return hypo |
|
|
|
for b in range(B): |
|
            # emissions is float32 (4 bytes per element) and batch-major, so this is
            # the byte offset of sample b's (T, N) emission block.
            emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
|
results = self.decoder.decode(emissions_ptr, T, N) |
|
|
|
nbest_results = results[: self.nbest] |
|
hypos.append([make_hypo(result) for result in nbest_results]) |
|
self.lm.empty_cache() |
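            # Drop cached LM states after each utterance to keep memory bounded.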
|
|
|
return hypos |
|
|