|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Deep speech decoder.""" |
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import itertools |
|
|
|
from nltk.metrics import distance |
|
import numpy as np |
|
|
|
|
|
class DeepSpeechDecoder(object):
  """Greedy decoder implementation for Deep Speech model.

  Decoding takes the arg-max label at every time step, collapses consecutive
  repeats, drops the blank label, and maps the remaining indexes to characters
  (greedy / best-path CTC decoding).

  Attributes:
    labels: the label string; position ``i`` is the character for class ``i``.
    blank_index: class index treated as blank and removed from the output.
    int_to_char: dict mapping each class index to its label character.
  """

  def __init__(self, labels, blank_index=28):
    """Decoder initialization.

    Args:
      labels: a string specifying the speech labels for the decoder to use.
      blank_index: an integer specifying index for the blank character.
        Defaults to 28.
    """
    self.labels = labels
    self.blank_index = blank_index
    self.int_to_char = dict(enumerate(labels))

  def convert_to_string(self, sequence):
    """Convert a sequence of indexes into corresponding string.

    Args:
      sequence: iterable of integer class indexes.

    Returns:
      The concatenation of the label character for each index, in order.

    Raises:
      KeyError: if an index has no entry in ``int_to_char``.
    """
    return ''.join(self.int_to_char[i] for i in sequence)

  def wer(self, decode, target):
    """Computes the word-level edit distance used for the Word Error Rate.

    Both sentences are tokenized on whitespace. Each distinct word is mapped
    to a unique stand-in character, and the character edit distance of the two
    resulting strings is returned. That equals the word-level edit distance.
    The value is NOT normalized: to obtain a rate, divide by the number of
    words in ``target``.

    Args:
      decode: string of the decoded output.
      target: a string for the ground truth label.

    Returns:
      An int: the word-level edit distance of the decode-target pair.
    """
    decode_words = decode.split()
    target_words = target.split()

    # One unique character per distinct word lets a character-level edit
    # distance operate at word granularity.
    # NOTE(review): chr() only covers 0-255 on Python 2, so vocabularies
    # larger than 256 words need Python 3 here.
    vocab = set(decode_words) | set(target_words)
    word2char = {word: chr(i) for i, word in enumerate(vocab)}

    new_decode = ''.join(word2char[w] for w in decode_words)
    new_target = ''.join(word2char[w] for w in target_words)

    return distance.edit_distance(new_decode, new_target)

  def cer(self, decode, target):
    """Computes the character-level edit distance used for the CER.

    The value is NOT normalized: to obtain a rate, divide by the number of
    characters in ``target``.

    Args:
      decode: a string of the decoded output.
      target: a string for the ground truth label.

    Returns:
      An int: the character edit distance between the two strings.
    """
    return distance.edit_distance(decode, target)

  def decode(self, logits):
    """Decode the best guess from logits using greedy algorithm.

    Args:
      logits: 2-D array-like of per-step label scores, reduced with arg-max
        over axis 1; presumably (time_steps, num_labels) -- TODO confirm
        with callers.

    Returns:
      The decoded string.
    """
    best_path = np.argmax(logits, axis=1)

    # Collapse runs before removing blanks, so a blank between two identical
    # labels keeps them as two characters ("a_a" -> "aa", but "aa" -> "a").
    collapsed = (label for label, _ in itertools.groupby(best_path))

    return self.convert_to_string(
        label for label in collapsed if label != self.blank_index)
|
|