alessandro trinca tornidor commited on
Commit
d1b2b5d
·
1 Parent(s): 4cafb0a

test: update test cases for pronunciationTrainer module

Browse files
aip_trainer/utils/split_cosmic_ray_report.py CHANGED
@@ -25,5 +25,5 @@ def get_cosmic_ray_report_filtered(input_filename, suffix="filtered", separator=
25
 
26
  if __name__ == "__main__":
27
  from aip_trainer import PROJECT_ROOT_FOLDER
28
- _input_filename = "cosmic-ray-lambdagetsample4.txt"
29
  get_cosmic_ray_report_filtered(PROJECT_ROOT_FOLDER / "tmp" / _input_filename)
 
25
 
26
  if __name__ == "__main__":
27
  from aip_trainer import PROJECT_ROOT_FOLDER
28
+ _input_filename = "cosmic-ray-pronunciationtrainer1.txt"
29
  get_cosmic_ray_report_filtered(PROJECT_ROOT_FOLDER / "tmp" / _input_filename)
cosmic_ray_config.toml CHANGED
@@ -1,8 +1,8 @@
1
  [cosmic-ray]
2
- module-path = "aip_trainer/WordMatching.py"
3
  timeout = 30.0
4
  excluded-modules = []
5
- test-command = "python -m pytest tests/test_worldmatching.py"
6
 
7
  [cosmic-ray.distributor]
8
  name = "local"
 
1
  [cosmic-ray]
2
+ module-path = "aip_trainer/pronunciationTrainer.py"
3
  timeout = 30.0
4
  excluded-modules = []
5
+ test-command = "python -m pytest tests/test_pronunciationtrainer.py"
6
 
7
  [cosmic-ray.distributor]
8
  name = "local"
tests/lambdas/test_lambdaSpeechToScore.py CHANGED
@@ -20,9 +20,7 @@ def set_seed(seed=0):
20
  torch.manual_seed(seed)
21
 
22
 
23
- def assert_raises_get_speech_to_score_dict(
24
- cls, real_text, file_bytes_or_audiotmpfile, language, exc, error_message
25
- ):
26
  from aip_trainer.lambdas import lambdaSpeechToScore
27
 
28
  with cls.assertRaises(exc):
 
20
  torch.manual_seed(seed)
21
 
22
 
23
+ def assert_raises_get_speech_to_score_dict(cls, real_text, file_bytes_or_audiotmpfile, language, exc, error_message):
 
 
24
  from aip_trainer.lambdas import lambdaSpeechToScore
25
 
26
  with cls.assertRaises(exc):
tests/test_pronunciationtrainer.py CHANGED
@@ -40,64 +40,90 @@ class TestScore(unittest.TestCase):
40
  def test_exact_transcription_de(self):
41
  set_seed()
42
  phrase_real = phrases["de"]["real"]
43
- real_and_transcribed_words, _, _ = trainer_SST_lambda_de.matchSampleAndRecordedWords(phrase_real, phrase_real)
44
- pronunciation_accuracy, _ = trainer_SST_lambda_de.getPronunciationAccuracy(real_and_transcribed_words)
 
 
45
  self.assertEqual(int(pronunciation_accuracy), 100)
 
46
 
47
  def test_transcription_de(self):
48
  set_seed()
49
  phrase_real = phrases["de"]["real"]
50
  phrase_transcribed = phrases["de"]["transcribed"]
51
- real_and_transcribed_words, _, _ = trainer_SST_lambda_de.matchSampleAndRecordedWords(phrase_real, phrase_transcribed)
52
- pronunciation_accuracy, _ = trainer_SST_lambda_de.getPronunciationAccuracy(real_and_transcribed_words)
 
 
53
  self.assertEqual(int(pronunciation_accuracy), 100)
 
54
 
55
  def test_partial_transcription_de(self):
56
  set_seed()
57
  phrase_real = phrases["de"]["real"]
58
  phrase_partial = phrases["de"]["partial"]
59
- real_and_transcribed_words, _, _ = trainer_SST_lambda_de.matchSampleAndRecordedWords(phrase_real, phrase_partial)
60
- pronunciation_accuracy, _ = trainer_SST_lambda_de.getPronunciationAccuracy(real_and_transcribed_words)
 
 
61
  self.assertEqual(int(pronunciation_accuracy), 71)
 
62
 
63
  def test_incorrect_transcription_with_correct_words_de(self):
64
  set_seed()
65
  phrase_real = phrases["de"]["real"]
66
  phrase_transcribed_incorrect = phrases["de"]["incorrect"]
67
- real_and_transcribed_words, _, _ = trainer_SST_lambda_de.matchSampleAndRecordedWords(phrase_real, phrase_transcribed_incorrect)
68
- pronunciation_accuracy, _ = trainer_SST_lambda_de.getPronunciationAccuracy(real_and_transcribed_words)
 
 
69
  self.assertEqual(int(pronunciation_accuracy), 71)
 
 
70
 
71
  def test_exact_transcription_en(self):
72
  set_seed()
73
  phrase_real = phrases["en"]["real"]
74
- real_and_transcribed_words, _, _ = trainer_SST_lambda_en.matchSampleAndRecordedWords(phrase_real, phrase_real)
75
- pronunciation_accuracy, _ = trainer_SST_lambda_en.getPronunciationAccuracy(real_and_transcribed_words)
 
 
76
  self.assertEqual(int(pronunciation_accuracy), 100)
 
77
 
78
  def test_transcription_en(self):
79
  set_seed()
80
  phrase_real = phrases["en"]["real"]
81
  phrase_transcribed = phrases["en"]["transcribed"]
82
- real_and_transcribed_words, _, _ = trainer_SST_lambda_en.matchSampleAndRecordedWords(phrase_real, phrase_transcribed)
83
- pronunciation_accuracy, _ = trainer_SST_lambda_en.getPronunciationAccuracy(real_and_transcribed_words)
 
 
84
  self.assertEqual(int(pronunciation_accuracy), 94)
 
85
 
86
  def test_partial_transcription_en(self):
87
  set_seed()
88
  phrase_real = phrases["en"]["real"]
89
  phrase_partial = phrases["en"]["partial"]
90
- real_and_transcribed_words, _, _ = trainer_SST_lambda_en.matchSampleAndRecordedWords(phrase_real, phrase_partial)
91
- pronunciation_accuracy, _ = trainer_SST_lambda_en.getPronunciationAccuracy(real_and_transcribed_words)
 
 
92
  self.assertEqual(int(pronunciation_accuracy), 56)
 
93
 
94
  def test_incorrect_transcription_with_correct_words_en(self):
95
  set_seed()
96
  phrase_real = phrases["en"]["real"]
97
  phrase_transcribed_incorrect = phrases["en"]["incorrect"]
98
- real_and_transcribed_words, _, _ = trainer_SST_lambda_en.matchSampleAndRecordedWords(phrase_real, phrase_transcribed_incorrect)
99
- pronunciation_accuracy, _ = trainer_SST_lambda_en.getPronunciationAccuracy(real_and_transcribed_words)
 
 
100
  self.assertEqual(int(pronunciation_accuracy), 69)
 
 
101
 
102
  def test_processAudioForGivenText_getTranscriptAndWordsLocations_de(self):
103
  set_seed()
@@ -202,6 +228,41 @@ class TestScore(unittest.TestCase):
202
  all_categories.append(category)
203
  self.assertEqual(all_categories, expected_categories)
204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
  if __name__ == '__main__':
207
  unittest.main()
 
40
  def test_exact_transcription_de(self):
41
  set_seed()
42
  phrase_real = phrases["de"]["real"]
43
+ real_and_transcribed_words, real_and_transcribed_words_ipa, mapped_words_indices = trainer_SST_lambda_de.matchSampleAndRecordedWords(phrase_real, phrase_real)
44
+ self.assertEqual(real_and_transcribed_words_ipa, [('haloː,', 'haloː,'), ('viː', 'viː'), ('ɡeːt', 'ɡeːt'), ('ɛːs', 'ɛːs'), ('diːr?', 'diːr?')])
45
+ self.assertEqual(mapped_words_indices, [0, 1, 2, 3, 4])
46
+ pronunciation_accuracy, current_words_pronunciation_accuracy = trainer_SST_lambda_de.getPronunciationAccuracy(real_and_transcribed_words)
47
  self.assertEqual(int(pronunciation_accuracy), 100)
48
+ self.assertEqual(current_words_pronunciation_accuracy, [100, 100, 100, 100, 100])
49
 
50
  def test_transcription_de(self):
51
  set_seed()
52
  phrase_real = phrases["de"]["real"]
53
  phrase_transcribed = phrases["de"]["transcribed"]
54
+ real_and_transcribed_words, real_and_transcribed_words_ipa, mapped_words_indices = trainer_SST_lambda_de.matchSampleAndRecordedWords(phrase_real, phrase_transcribed)
55
+ self.assertEqual(real_and_transcribed_words_ipa, [('haloː,', 'haloː'), ('viː', 'viː'), ('ɡeːt', 'ɡeːt'), ('ɛːs', 'ɛːs'), ('diːr?', 'diːɐ̯')])
56
+ self.assertEqual(mapped_words_indices, [0, 1, 2, 3, 4])
57
+ pronunciation_accuracy, current_words_pronunciation_accuracy= trainer_SST_lambda_de.getPronunciationAccuracy(real_and_transcribed_words)
58
  self.assertEqual(int(pronunciation_accuracy), 100)
59
+ self.assertEqual(current_words_pronunciation_accuracy, [100, 100, 100, 100, 100])
60
 
61
  def test_partial_transcription_de(self):
62
  set_seed()
63
  phrase_real = phrases["de"]["real"]
64
  phrase_partial = phrases["de"]["partial"]
65
+ real_and_transcribed_words, real_and_transcribed_words_ipa, mapped_words_indices = trainer_SST_lambda_de.matchSampleAndRecordedWords(phrase_real, phrase_partial)
66
+ pronunciation_accuracy, current_words_pronunciation_accuracy= trainer_SST_lambda_de.getPronunciationAccuracy(real_and_transcribed_words)
67
+ self.assertEqual(real_and_transcribed_words_ipa, [('haloː,', 'haloː'), ('viː', 'viː'), ('ɡeːt', 'ɡeːt'), ('ɛːs', '-'), ('diːr?', '-')])
68
+ self.assertEqual(mapped_words_indices, [0, 1, 2, -1, -1])
69
  self.assertEqual(int(pronunciation_accuracy), 71)
70
+ self.assertEqual(current_words_pronunciation_accuracy, [100, 100, 100, 0, 0])
71
 
72
  def test_incorrect_transcription_with_correct_words_de(self):
73
  set_seed()
74
  phrase_real = phrases["de"]["real"]
75
  phrase_transcribed_incorrect = phrases["de"]["incorrect"]
76
+ real_and_transcribed_words, real_and_transcribed_words_ipa, mapped_words_indices = trainer_SST_lambda_de.matchSampleAndRecordedWords(phrase_real, phrase_transcribed_incorrect)
77
+ self.assertEqual(real_and_transcribed_words_ipa, [('haloː,', 'haɪ̯l'), ('viː', 'viː'), ('ɡeːt', 'ɡiːt'), ('ɛːs', 'ɛːs'), ('diːr?', 'diːɐ̯')])
78
+ self.assertEqual(mapped_words_indices, [0, 1, 2, 3, 4])
79
+ pronunciation_accuracy, current_words_pronunciation_accuracy= trainer_SST_lambda_de.getPronunciationAccuracy(real_and_transcribed_words)
80
  self.assertEqual(int(pronunciation_accuracy), 71)
81
+ for accuracy, expected_accuracy in zip(current_words_pronunciation_accuracy, [60.0, 66.666666, 50.0, 100.0, 100.0]):
82
+ self.assertAlmostEqual(accuracy, expected_accuracy, places=2)
83
 
84
  def test_exact_transcription_en(self):
85
  set_seed()
86
  phrase_real = phrases["en"]["real"]
87
+ real_and_transcribed_words, real_and_transcribed_words_ipa, mapped_words_indices = trainer_SST_lambda_en.matchSampleAndRecordedWords(phrase_real, phrase_real)
88
+ self.assertEqual(real_and_transcribed_words_ipa, [('haɪ', 'haɪ'), ('ðɛr,', 'ðɛr,'), ('haʊ', 'haʊ'), ('ər', 'ər'), ('ju?', 'ju?')])
89
+ self.assertEqual(mapped_words_indices, [0, 1, 2, 3, 4])
90
+ pronunciation_accuracy, current_words_pronunciation_accuracy= trainer_SST_lambda_en.getPronunciationAccuracy(real_and_transcribed_words)
91
  self.assertEqual(int(pronunciation_accuracy), 100)
92
+ self.assertEqual(current_words_pronunciation_accuracy, [100, 100, 100, 100, 100])
93
 
94
  def test_transcription_en(self):
95
  set_seed()
96
  phrase_real = phrases["en"]["real"]
97
  phrase_transcribed = phrases["en"]["transcribed"]
98
+ real_and_transcribed_words, real_and_transcribed_words_ipa, mapped_words_indices = trainer_SST_lambda_en.matchSampleAndRecordedWords(phrase_real, phrase_transcribed)
99
+ self.assertEqual(real_and_transcribed_words_ipa, [('haɪ', 'aɪ'), ('ðɛr,', 'ðɛr'), ('haʊ', 'haʊ'), ('ər', 'ər'), ('ju?', 'ju')])
100
+ self.assertEqual(mapped_words_indices, [0, 1, 2, 3, 4])
101
+ pronunciation_accuracy, current_words_pronunciation_accuracy= trainer_SST_lambda_en.getPronunciationAccuracy(real_and_transcribed_words)
102
  self.assertEqual(int(pronunciation_accuracy), 94)
103
+ self.assertEqual(current_words_pronunciation_accuracy, [50.0, 100.0, 100.0, 100.0, 100.0])
104
 
105
  def test_partial_transcription_en(self):
106
  set_seed()
107
  phrase_real = phrases["en"]["real"]
108
  phrase_partial = phrases["en"]["partial"]
109
+ real_and_transcribed_words, real_and_transcribed_words_ipa, mapped_words_indices = trainer_SST_lambda_en.matchSampleAndRecordedWords(phrase_real, phrase_partial)
110
+ self.assertEqual(real_and_transcribed_words_ipa, [('haɪ', 'aɪ'), ('ðɛr,', 'ðɛr'), ('haʊ', 'haʊ'), ('ər', ''), ('ju?', '')])
111
+ self.assertEqual(mapped_words_indices, [0, 1, 2, -1, -1])
112
+ pronunciation_accuracy, current_words_pronunciation_accuracy= trainer_SST_lambda_en.getPronunciationAccuracy(real_and_transcribed_words)
113
  self.assertEqual(int(pronunciation_accuracy), 56)
114
+ self.assertEqual(current_words_pronunciation_accuracy, [50.0, 100.0, 100.0, 0.0, 0.0])
115
 
116
  def test_incorrect_transcription_with_correct_words_en(self):
117
  set_seed()
118
  phrase_real = phrases["en"]["real"]
119
  phrase_transcribed_incorrect = phrases["en"]["incorrect"]
120
+ real_and_transcribed_words, real_and_transcribed_words_ipa, mapped_words_indices = trainer_SST_lambda_en.matchSampleAndRecordedWords(phrase_real, phrase_transcribed_incorrect)
121
+ self.assertEqual(real_and_transcribed_words_ipa, [('haɪ', 'aɪ'), ('ðɛr,', 'hir'), ('haʊ', 'haʊ'), ('ər', 'ri'), ('ju?', 'juθ')])
122
+ self.assertEqual(mapped_words_indices, [0, 1, 2, 3, 4])
123
+ pronunciation_accuracy, current_words_pronunciation_accuracy= trainer_SST_lambda_en.getPronunciationAccuracy(real_and_transcribed_words)
124
  self.assertEqual(int(pronunciation_accuracy), 69)
125
+ for accuracy, expected_accuracy in zip(current_words_pronunciation_accuracy, [50.0, 80.0, 100.0, 66.666666, 33.333333]):
126
+ self.assertAlmostEqual(accuracy, expected_accuracy, places=2)
127
 
128
  def test_processAudioForGivenText_getTranscriptAndWordsLocations_de(self):
129
  set_seed()
 
228
  all_categories.append(category)
229
  self.assertEqual(all_categories, expected_categories)
230
 
231
+ def test_matchSampleAndRecordedWords(self):
232
+ set_seed()
233
+ phrase_real = phrases["de"]["real"]
234
+ phrase_transcribed = phrases["de"]["transcribed"]
235
+ real_and_transcribed_words, real_words, transcribed_words = trainer_SST_lambda_de.matchSampleAndRecordedWords(phrase_real, phrase_transcribed)
236
+ self.assertIsInstance(real_and_transcribed_words, list)
237
+ self.assertIsInstance(real_words, list)
238
+ self.assertIsInstance(transcribed_words, list)
239
+ self.assertEqual(len(real_and_transcribed_words), len(real_words))
240
+ self.assertEqual(len(real_and_transcribed_words), len(transcribed_words))
241
+
242
+ def test_removePunctuation_en(self):
243
+ word = "hello,"
244
+ cleaned_word = trainer_SST_lambda_en.removePunctuation(word)
245
+ self.assertEqual(cleaned_word, "hello")
246
+ word = "hello,\n\rworld..."
247
+ cleaned_word = trainer_SST_lambda_en.removePunctuation(word)
248
+ self.assertEqual(cleaned_word, "hello\n\rworld")
249
+
250
+ def test_getWordsPronunciationCategory_en(self):
251
+ accuracies = [x for x in range(-121, 121, 10)] + [np.inf, -np.inf, np.nan, 1.5, -1.5]
252
+ expected_categories = [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2]
253
+ categories = trainer_SST_lambda_en.getWordsPronunciationCategory(accuracies)
254
+ self.assertEqual(categories, expected_categories)
255
+
256
+ def test_preprocessAudio_en(self):
257
+ output_hash = utilities.hash_calculate(signal_en, is_file=False)
258
+ assert output_hash == b'zBAV/y7mecyPHLGiitHRP9vK7oU9hnYvyuatU0PQfts='
259
+ signal_transformed = transform(torch.Tensor(signal_en)).unsqueeze(0)
260
+ processed_audio = trainer_SST_lambda_en.preprocessAudio(signal_transformed)
261
+ self.assertIsInstance(processed_audio, torch.Tensor)
262
+ self.assertEqual(processed_audio.shape, (1, 16800))
263
+ output_hash = utilities.hash_calculate(processed_audio.numpy(), is_file=False)
264
+ assert output_hash == b'KsyH1MXIc+5e5B6CcijhitsGPUDRJjrJU2qg8bQi600='
265
+
266
 
267
  if __name__ == '__main__':
268
  unittest.main()