import unittest

from hw_asr.text_encoder.ctc_char_text_encoder import CTCCharTextEncoder


class TestTextEncoder(unittest.TestCase):
    """Unit tests for ``CTCCharTextEncoder``."""

    def test_ctc_decode(self):
        """Greedy CTC decoding of a hand-written index sequence.

        ``text`` is a frame-level character sequence in which ``^`` appears
        between characters; decoding it must collapse runs of a repeated
        character into one and drop every ``^``, yielding ``true_text``.
        NOTE(review): assumes ``^`` is the encoder's empty (blank) token —
        every character of ``text`` must be a key of ``char2ind`` for the
        index lookup below to succeed.
        """
        text_encoder = CTCCharTextEncoder()
        text = (
            "i^^ ^w^i^sss^hhh^ i ^^^s^t^aaaar^teee^d "
            "dddddd^oooo^in^g tttttttth^iiiis h^^^^^^^^w^ e^a^r^li^er"
        )
        true_text = "i wish i started doing this hw earlier"
        inds = [text_encoder.char2ind[c] for c in text]
        decoded_text = text_encoder.ctc_decode(inds)
        # assertEqual rather than assertIn: assertIn(decoded_text, true_text)
        # only checks that the decoded string is a *substring* of the
        # expected sentence, so an empty or truncated decode would pass.
        self.assertEqual(decoded_text, true_text)

    # TODO: (optional) write tests for beam search
    # (``text_encoder.ctc_beam_search``). Presumably this needs a
    # probability matrix with one column per entry of
    # ``text_encoder.ind2char`` — confirm against the method's signature.