# NOTE(review): the three lines below look like paste artifacts from a judge/console
# ("Spaces:" / "Runtime error") — commented out so the module parses; original text kept.
# Spaces:
# Runtime error
# Runtime error
from pandas.core.construction import T | |
import torch | |
import jiwer | |
import re | |
class MispronounciationDetector:
    """Detects mispronunciations by comparing an L2 (non-native) speaker's
    recognised phoneme sequence against the expected native-speaker phoneme
    sequence produced by a G2P model.

    NOTE(review): the "Mispronounciation" spelling is kept in the class and
    method names for backward compatibility with existing callers.
    """

    def __init__(self, l2_phoneme_recogniser, g2p, device):
        """
        :param l2_phoneme_recogniser: PhonemeASRModel instance used to
            transcribe the learner's audio into a phoneme sequence
        :param g2p: grapheme-to-phoneme callable, text -> phoneme list
        :param device: device the ASR model runs on (stored for callers)
        """
        self.phoneme_asr_model = l2_phoneme_recogniser  # PhonemeASRModel class
        self.g2p = g2p
        self.device = device

    def detect(self, audio, text, phoneme_error_threshold=0.25):
        """Run end-to-end mispronunciation detection on one utterance.

        :param audio: audio input accepted by the phoneme ASR model
        :param text: transcript the learner was asked to read
        :param phoneme_error_threshold: fraction of wrong phonemes in a word
            above which the word is counted as mispronounced
        :return: dictionary of alignment strings, PER, WER and error flags
            (see :meth:`get_mispronounciation_output`)
        """
        l2_phones = self.phoneme_asr_model.get_l2_phoneme_sequence(audio)
        # g2p output carries no lexical stress, so strip stress digits
        # (e.g. "AH0" -> "AH") to make the two sequences comparable
        l2_phones = [re.sub(r'\d', "", phone_str) for phone_str in l2_phones]
        native_speaker_phones = self.get_native_speaker_phoneme_sequence(text)
        standardised_native_speaker_phones = self.phoneme_asr_model.standardise_g2p_phoneme_sequence(native_speaker_phones)
        raw_info = self.get_mispronounciation_output(text, l2_phones, standardised_native_speaker_phones, phoneme_error_threshold)
        return raw_info

    def get_native_speaker_phoneme_sequence(self, text):
        """Return the expected (native-speaker) phoneme sequence for *text*."""
        return self.g2p(text)

    def get_mispronounciation_output(self, text, pred_phones, org_label_phones, phoneme_error_threshold):
        """
        Aligns the predicted phones from the L2 speaker with the expected
        native-speaker phones to locate the errors.

        :param text: original words read by the user
        :type text: string
        :param pred_phones: phonemes predicted for the L2 speaker by the ASR model
        :type pred_phones: array
        :param org_label_phones: correct, native-speaker phonemes from G2P where
            the phonemes of each word are separated by " "
        :type org_label_phones: array
        :param phoneme_error_threshold: per-word error fraction above which a
            word counts as mispronounced
        :return: dictionary with PER, WER and error flags at phoneme/word level
            (keys: ref, hyp, per, phoneme_errors, wer, words, word_errors)
        :rtype: dictionary
        """
        # --- phoneme error rate ---
        # Treat each phoneme as a "word" so jiwer's word-level alignment
        # yields a phoneme-level alignment; res.wer is therefore the PER.
        label_phones = [phone for phone in org_label_phones if phone != " "]
        reference = " ".join(label_phones)  # dummy phones
        hypothesis = " ".join(pred_phones)  # dummy l2 speaker phones
        res = jiwer.process_words(reference, hypothesis)
        per = res.wer

        # --- build column-aligned ref / hyp / error rows from the chunks ---
        alignments = res.alignments
        error_bool = []
        ref, hyp = [], []
        for alignment_chunk in alignments[0]:
            alignment_type = alignment_chunk.type
            ref_start_idx = alignment_chunk.ref_start_idx
            ref_end_idx = alignment_chunk.ref_end_idx
            hyp_start_idx = alignment_chunk.hyp_start_idx
            hyp_end_idx = alignment_chunk.hyp_end_idx
            if alignment_type == "equal":
                ref.extend(label_phones[ref_start_idx:ref_end_idx])
                hyp.extend(pred_phones[hyp_start_idx:hyp_end_idx])
                # ref or hyp does not matter here: the lengths are equal
                for i in range(ref_start_idx, ref_end_idx):
                    error_bool.append("-" * len(label_phones[i]))
            elif alignment_type == "insert":
                # learner added phones: pad the reference row with '*'
                # and flag the column "a" (addition)
                for i in range(hyp_start_idx, hyp_end_idx):
                    ref.append("*" * len(pred_phones[i]))
                    error_bool.append(" " * (len(pred_phones[i]) - 1) + "a")
                hyp.extend(pred_phones[hyp_start_idx:hyp_end_idx])
            elif alignment_type == "delete":
                # learner dropped phones: pad the hypothesis row with '*'
                # and flag the column "d" (deletion)
                ref.extend(label_phones[ref_start_idx:ref_end_idx])
                for i in range(ref_start_idx, ref_end_idx):
                    hyp.append("*" * len(label_phones[i]))
                    error_bool.append(" " * (len(label_phones[i]) - 1) + "d")
            else:
                # substitution: right-align both phones to a common width
                # and flag the column "s"
                for i in range(ref_end_idx - ref_start_idx):
                    correct_phone = label_phones[ref_start_idx + i]
                    pred_phone = pred_phones[hyp_start_idx + i]
                    width = max(len(correct_phone), len(pred_phone))
                    ref.append(correct_phone.rjust(width))
                    hyp.append(pred_phone.rjust(width))
                    error_bool.append(" " * (width - 1) + "s")

        # --- insert word delimiters to show phoneme sections per word ---
        delimiter_idx = 0
        for phone in org_label_phones:
            if phone == " ":
                # word boundary in the G2P output: insert "|" just after the
                # last matched phone of the previous word
                hyp.insert(delimiter_idx + 1, "|")
                ref.insert(delimiter_idx + 1, "|")
                error_bool.insert(delimiter_idx + 1, "|")
                continue
            # advance to the ref position holding this label phone
            # (skipping '*' padding entries created by insertions)
            while delimiter_idx < len(ref) and ref[delimiter_idx].strip() != phone:
                delimiter_idx += 1
        # close off the final word. error_bool previously received its
        # trailing "|" as a side effect of get_mispronounced_words mutating
        # its argument; it is now appended explicitly here instead.
        ref.append("|")
        hyp.append("|")
        error_bool.append("|")

        # --- word-level errors: a word is wrong when the fraction of its ---
        # --- phonemes flagged as errors exceeds the threshold            ---
        words = text.split(" ")
        word_error_bool = self.get_mispronounced_words(error_bool, phoneme_error_threshold)
        wer = sum(word_error_bool) / len(words)
        raw_info = {"ref": ref, "hyp": hyp, "per": per, "phoneme_errors": error_bool, "wer": wer, "words": words, "word_errors": word_error_bool}
        return raw_info

    def get_mispronounced_words(self, phoneme_error_bool, phoneme_error_threshold):
        """Map phoneme-level error flags back to per-word error booleans.

        :param phoneme_error_bool: per-phoneme status strings ("-" runs for
            correct phones; "s"/"d"/"a" right-padded with spaces for errors),
            with "|" separating words. The input list is NOT mutated.
        :param phoneme_error_threshold: a word is mispronounced when the
            fraction of its phonemes flagged as errors exceeds this value
        :return: list of bools, one per word
        """
        # work on a copy so the caller's list is not mutated (the original
        # implementation appended "|" to the argument as a side effect)
        flags = list(phoneme_error_bool)
        if not flags:
            return []
        if flags[-1] != "|":
            flags.append("|")  # ensure the final word group is terminated
        word_error_bool = []
        for phones in self.split_lst_by_delim(flags, "|"):
            error_count = sum(1 for phone in phones if phone in ("s", "d", "a"))
            # empty group guard avoids ZeroDivisionError on stray delimiters
            is_wrong = bool(phones) and error_count / len(phones) > phoneme_error_threshold
            word_error_bool.append(is_wrong)
        return word_error_bool

    def split_lst_by_delim(self, lst, delimiter):
        """Split *lst* into sublists on *delimiter*, stripping whitespace from
        each kept item.

        Note: items after the final delimiter are dropped — callers are
        expected to terminate the list with a delimiter.
        """
        res = []
        temp = []
        for item in lst:
            if item == delimiter:
                res.append(temp)
                temp = []
            else:
                temp.append(item.strip())
        return res