import unittest | |
from hw_asr.text_encoder.ctc_char_text_encoder import CTCCharTextEncoder | |
class TestTextEncoder(unittest.TestCase):
    """Unit tests for ``CTCCharTextEncoder``."""

    def test_ctc_decode(self):
        """Check that ``ctc_decode`` reproduces the expected sentence exactly.

        The raw input contains ``^`` symbols (presumably the CTC blank token;
        confirm against the encoder) and runs of repeated characters. It is
        mapped to indices through ``char2ind``, decoded, and compared with
        the expected plain-text sentence.
        """
        text_encoder = CTCCharTextEncoder()
        text = (
            "i^^ ^w^i^sss^hhh^ i ^^^s^t^aaaar^teee^d "
            "dddddd^oooo^in^g tttttttth^iiiis h^^^^^^^^w^ e^a^r^li^er"
        )
        true_text = "i wish i started doing this hw earlier"
        inds = [text_encoder.char2ind[c] for c in text]
        decoded_text = text_encoder.ctc_decode(inds)
        # Require exact equality: a substring check (assertIn) would also
        # pass for an empty string or any fragment of the expected sentence.
        self.assertEqual(decoded_text, true_text)

    # TODO: (optional) write tests for beam search. Unfinished sketch:
    # def test_beam_search(self):
    #     text_encoder = CTCCharTextEncoder()
    #     len(text_encoder.ind2char)
    #     probs
    #     text_encoder.ctc_beam_search()