tgritsaev's picture
Upload 198 files
affcd23 verified
raw
history blame
795 Bytes
import unittest
from hw_asr.text_encoder.ctc_char_text_encoder import CTCCharTextEncoder
class TestTextEncoder(unittest.TestCase):
    """Unit tests for ``CTCCharTextEncoder``."""

    def test_ctc_decode(self):
        """CTC decoding of a frame sequence must reproduce the reference text.

        The test expects ``^`` to act as the blank token: adjacent repeated
        characters collapse into one, a ``^`` between two identical characters
        keeps them distinct, and every ``^`` is dropped from the output.
        Every character of ``text`` must be a key of
        ``text_encoder.char2ind``.
        """
        text_encoder = CTCCharTextEncoder()
        text = (
            "i^^ ^w^i^sss^hhh^ i ^^^s^t^aaaar^teee^d "
            "dddddd^oooo^in^g tttttttth^iiiis h^^^^^^^^w^ e^a^r^li^er"
        )
        true_text = "i wish i started doing this hw earlier"
        inds = [text_encoder.char2ind[c] for c in text]
        decoded_text = text_encoder.ctc_decode(inds)
        # Use assertEqual, not assertIn: assertIn(decoded, true) only checks
        # that the decoded string is a *substring* of the reference, so an
        # empty or truncated decode would wrongly pass.
        self.assertEqual(decoded_text, true_text)

    # TODO: (optional) write tests for beam search, e.g. build a ``probs``
    # matrix with ``len(text_encoder.ind2char)`` columns and check the result
    # of ``text_encoder.ctc_beam_search(...)`` (its signature is not visible
    # here -- confirm before writing the test).