Spaces:
Runtime error
Runtime error
Introduce uncertainty to word error with PER threshold
Browse files
wav2vecasr/MispronounciationDetector.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
from pandas.core.construction import T
|
2 |
import torch
|
3 |
import jiwer
|
|
|
4 |
|
5 |
class MispronounciationDetector:
|
6 |
def __init__(self, l2_phoneme_recogniser, g2p, device):
|
@@ -8,18 +9,19 @@ class MispronounciationDetector:
|
|
8 |
self.g2p = g2p
|
9 |
self.device = device
|
10 |
|
11 |
-
def detect(self, audio, text):
|
12 |
l2_phones = self.phoneme_asr_model.get_l2_phoneme_sequence(audio)
|
|
|
13 |
native_speaker_phones = self.get_native_speaker_phoneme_sequence(text)
|
14 |
standardised_native_speaker_phones = self.phoneme_asr_model.standardise_g2p_phoneme_sequence(native_speaker_phones)
|
15 |
-
raw_info = self.get_mispronounciation_output(text, l2_phones, standardised_native_speaker_phones)
|
16 |
return raw_info
|
17 |
|
18 |
def get_native_speaker_phoneme_sequence(self, text):
|
19 |
phonemes = self.g2p(text)
|
20 |
return phonemes
|
21 |
|
22 |
-
def get_mispronounciation_output(self, text, pred_phones, org_label_phones):
|
23 |
"""
|
24 |
Aligns the predicted phones from the L2 speaker and the expected native speaker phone to get the errors
|
25 |
:param text: original words read by the user
|
@@ -101,7 +103,7 @@ class MispronounciationDetector:
|
|
101 |
# get mispronounced words based on if there are phoneme errors present in the phonemes of that word
|
102 |
aligned_word_error_output = ""
|
103 |
words = text.split(" ")
|
104 |
-
word_error_bool = self.get_mispronounced_words(error_bool)
|
105 |
wer = sum(word_error_bool) / len(words)
|
106 |
|
107 |
raw_info = {"ref":ref, "hyp": hyp, "per":per, "phoneme_errors": error_bool, "wer": wer, "words": words, "word_errors":word_error_bool}
|
@@ -109,16 +111,27 @@ class MispronounciationDetector:
|
|
109 |
return raw_info
|
110 |
|
111 |
|
112 |
-
def get_mispronounced_words(self, phoneme_error_bool):
|
113 |
# map mispronounced phones back to words that were mispronounce to get WER
|
114 |
word_error_bool = []
|
115 |
phoneme_error_bool.append("|")
|
116 |
word_phones = self.split_lst_by_delim(phoneme_error_bool, "|")
|
|
|
|
|
117 |
for phones in word_phones:
|
118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
word_error_bool.append(True)
|
120 |
else:
|
121 |
word_error_bool.append(False)
|
|
|
122 |
return word_error_bool
|
123 |
|
124 |
|
|
|
1 |
from pandas.core.construction import T
|
2 |
import torch
|
3 |
import jiwer
|
4 |
+
import re
|
5 |
|
6 |
class MispronounciationDetector:
|
7 |
def __init__(self, l2_phoneme_recogniser, g2p, device):
|
|
|
9 |
self.g2p = g2p
|
10 |
self.device = device
|
11 |
|
12 |
+
def detect(self, audio, text, phoneme_error_threshold=0.25):
|
13 |
l2_phones = self.phoneme_asr_model.get_l2_phoneme_sequence(audio)
|
14 |
+
l2_phones = [re.sub(r'\d', "", phone_str) for phone_str in l2_phones] #g2p has no lexical stress
|
15 |
native_speaker_phones = self.get_native_speaker_phoneme_sequence(text)
|
16 |
standardised_native_speaker_phones = self.phoneme_asr_model.standardise_g2p_phoneme_sequence(native_speaker_phones)
|
17 |
+
raw_info = self.get_mispronounciation_output(text, l2_phones, standardised_native_speaker_phones, phoneme_error_threshold)
|
18 |
return raw_info
|
19 |
|
20 |
def get_native_speaker_phoneme_sequence(self, text):
|
21 |
phonemes = self.g2p(text)
|
22 |
return phonemes
|
23 |
|
24 |
+
def get_mispronounciation_output(self, text, pred_phones, org_label_phones, phoneme_error_threshold):
|
25 |
"""
|
26 |
Aligns the predicted phones from the L2 speaker and the expected native speaker phone to get the errors
|
27 |
:param text: original words read by the user
|
|
|
103 |
# get mispronounced words based on if there are phoneme errors present in the phonemes of that word
|
104 |
aligned_word_error_output = ""
|
105 |
words = text.split(" ")
|
106 |
+
word_error_bool = self.get_mispronounced_words(error_bool, phoneme_error_threshold)
|
107 |
wer = sum(word_error_bool) / len(words)
|
108 |
|
109 |
raw_info = {"ref":ref, "hyp": hyp, "per":per, "phoneme_errors": error_bool, "wer": wer, "words": words, "word_errors":word_error_bool}
|
|
|
111 |
return raw_info
|
112 |
|
113 |
|
114 |
+
def get_mispronounced_words(self, phoneme_error_bool, phoneme_error_threshold):
|
115 |
# map mispronounced phones back to words that were mispronounce to get WER
|
116 |
word_error_bool = []
|
117 |
phoneme_error_bool.append("|")
|
118 |
word_phones = self.split_lst_by_delim(phoneme_error_bool, "|")
|
119 |
+
|
120 |
+
# wrong only if percentage of phones that are wrong > phoneme error threshold
|
121 |
for phones in word_phones:
|
122 |
+
|
123 |
+
# get count of "s", "d", "a" in phones
|
124 |
+
error_count = 0
|
125 |
+
for phone in phones:
|
126 |
+
if phone == "s" or phone == "d" or phone == "a":
|
127 |
+
error_count += 1
|
128 |
+
|
129 |
+
# check if pass threshold
|
130 |
+
if error_count / len(phones) > phoneme_error_threshold:
|
131 |
word_error_bool.append(True)
|
132 |
else:
|
133 |
word_error_bool.append(False)
|
134 |
+
|
135 |
return word_error_bool
|
136 |
|
137 |
|