import unittest

import torch

from training.preprocess.wav2vec_aligner import Wav2VecAligner

# Token ids that Wav2VecAligner.encode produces for TEST_TEXT below.
# Shared by the encode test and the decode round-trip test so the fixture
# is defined exactly once.
EXPECTED_TOKEN_IDS = [
    [
        10, 4, 11, 7, 14, 4, 6, 11, 7, 6, 4, 19, 16, 13, 10, 8, 12, 10,
        6, 22, 4, 24, 5, 12, 10, 14, 5, 4, 17, 5, 4, 7, 6, 4, 6, 11, 10,
        12, 4, 17, 8, 17, 5, 9, 6,
    ],
]


class TestWav2VecAligner(unittest.TestCase):
    """Tests for the Wav2VecAligner forced-alignment pipeline.

    Each test exercises one stage of the pipeline (encode/decode,
    emission computation, trellis, backtracking, merging) against a
    mock audio file and a fixed transcript.
    """

    def setUp(self):
        # Fresh aligner per test; the transcript and audio fixture are
        # the same for every stage.
        self.model = Wav2VecAligner()
        self.text = "I HAD THAT CURIOSITY BESIDE ME AT THIS MOMENT"
        self.wav_path = "./mocks/audio_example.wav"

    def _aligned(self):
        """Load the mock audio and run a single alignment pass.

        Returns the ``(emissions, tokens, transcript)`` triple from
        ``align_single_sample``; used by every test that needs the
        alignment artifacts, to avoid repeating the load/align boilerplate.
        """
        audio_input, _ = self.model.load_audio(self.wav_path)
        return self.model.align_single_sample(audio_input, self.text)

    def test_load_audio(self):
        _, sample_rate = self.model.load_audio(self.wav_path)
        self.assertEqual(sample_rate, 16_000)

        # A missing file must surface as FileNotFoundError, not be swallowed.
        with self.assertRaises(FileNotFoundError):
            self.model.load_audio("./nonexistent/path.wav")

    def test_encode(self):
        tokens = self.model.encode(self.text)
        torch.testing.assert_close(tokens, torch.tensor(EXPECTED_TOKEN_IDS))

    def test_decode(self):
        # Round trip: decoding the known token ids must reproduce the text.
        transcript = self.model.decode(EXPECTED_TOKEN_IDS)
        self.assertEqual(transcript, self.text)

    def test_align_single_sample(self):
        emissions, tokens, transcript = self._aligned()

        # 169 emission frames over a 32-symbol vocabulary for this clip.
        self.assertEqual(emissions.shape, torch.Size([169, 32]))
        self.assertEqual(len(tokens), 47)
        self.assertEqual(
            transcript, "|I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|"
        )

    def test_get_trellis(self):
        emissions, tokens, _ = self._aligned()
        trellis = self.model.get_trellis(emissions, tokens)

        self.assertEqual(emissions.shape, torch.Size([169, 32]))
        self.assertEqual(len(tokens), 47)

        # Trellis is (frames x tokens) scores.
        self.assertIsInstance(trellis, torch.Tensor)
        self.assertEqual(trellis.shape, torch.Size([169, 47]))

    def test_backtrack(self):
        emissions, tokens, _ = self._aligned()
        trellis = self.model.get_trellis(emissions, tokens)
        path = self.model.backtrack(trellis, emissions, tokens)

        # One path point per emission frame.
        self.assertIsInstance(path, list)
        self.assertEqual(len(path), 169)

    def test_merge_repeats(self):
        emissions, tokens, transcript = self._aligned()
        trellis = self.model.get_trellis(emissions, tokens)
        path = self.model.backtrack(trellis, emissions, tokens)
        merged_path = self.model.merge_repeats(path, transcript)

        # After merging repeated frames there is one segment per token.
        self.assertIsInstance(merged_path, list)
        self.assertEqual(len(merged_path), 47)

    def test_merge_words(self):
        emissions, tokens, transcript = self._aligned()
        trellis = self.model.get_trellis(emissions, tokens)
        path = self.model.backtrack(trellis, emissions, tokens)
        merged_path = self.model.merge_repeats(path, transcript)
        merged_words = self.model.merge_words(merged_path)

        # The transcript contains nine words.
        self.assertIsInstance(merged_words, list)
        self.assertEqual(len(merged_words), 9)

    def test_forward(self):
        # End-to-end call: one aligned segment per word.
        result = self.model(self.wav_path, self.text)
        self.assertEqual(len(result), 9)

    def test_save_segments(self):
        # TODO(review): the real assertion is disabled because it writes
        # audio files to disk; re-enable with a tmp dir when feasible.
        # self.model.save_segments(self.wav_path, self.text, "./mocks/wav2vec_aligner/audio")
        self.assertEqual(True, True)


if __name__ == "__main__":
    unittest.main()