from pandas.core.construction import T  # NOTE(review): unused here — looks like an accidental IDE auto-import; confirm before removing
import torch  # NOTE(review): unused in this module as visible; confirm before removing
import jiwer
import re


class MispronounciationDetector:
    # NOTE(review): "Mispronounciation" is a misspelling of "Mispronunciation",
    # but the name is part of the public interface and is kept as-is.
    """Detects mispronunciations by aligning the phoneme sequence an L2 (second-language)
    speaker actually produced (from a phoneme ASR model) against the phoneme sequence a
    native speaker would produce (from a G2P model), then deriving phoneme- and
    word-level error information.
    """

    def __init__(self, l2_phoneme_recogniser, g2p, device):
        self.phoneme_asr_model = l2_phoneme_recogniser  # PhonemeASRModel class
        self.g2p = g2p
        self.device = device

    def detect(self, audio, text, phoneme_error_threshold=0.25):
        """Run mispronunciation detection on one utterance.

        :param audio: audio input accepted by the ASR model's get_l2_phoneme_sequence
        :param text: the text the user was asked to read
        :type text: string
        :param phoneme_error_threshold: fraction of erroneous phonemes in a word above
            which the whole word is flagged as mispronounced
        :type phoneme_error_threshold: float
        :return: dictionary of mispronunciation info (see get_mispronounciation_output)
        :rtype: dictionary
        """
        l2_phones = self.phoneme_asr_model.get_l2_phoneme_sequence(audio)
        # g2p output carries no lexical stress, so strip stress digits (e.g. "AH0" -> "AH")
        # from the ASR phones to make the two sequences comparable
        l2_phones = [re.sub(r'\d', "", phone_str) for phone_str in l2_phones]
        native_speaker_phones = self.get_native_speaker_phoneme_sequence(text)
        standardised_native_speaker_phones = self.phoneme_asr_model.standardise_g2p_phoneme_sequence(native_speaker_phones)
        raw_info = self.get_mispronounciation_output(text, l2_phones, standardised_native_speaker_phones, phoneme_error_threshold)
        return raw_info

    def get_native_speaker_phoneme_sequence(self, text):
        """Return the G2P phoneme sequence for *text* (words separated by " " entries)."""
        phonemes = self.g2p(text)
        return phonemes

    def get_mispronounciation_output(self, text, pred_phones, org_label_phones, phoneme_error_threshold):
        """
        Aligns the predicted phones from the L2 speaker and the expected native speaker
        phones to locate the errors.

        :param text: original words read by the user
        :type text: string
        :param pred_phones: predicted phonemes by L2 speaker from ASR Model
        :type pred_phones: array
        :param org_label_phones: correct, native speaker phonemes from G2P where
            phonemes of each word are segregated by " "
        :type org_label_phones: array
        :param phoneme_error_threshold: per-word error fraction above which the word is
            flagged as mispronounced
        :type phoneme_error_threshold: float
        :return: dictionary containing various mispronunciation information like PER,
            WER and error boolean arrays at phoneme/word level
        :rtype: dictionary
        """
        # --- phoneme error rate (PER) ---
        # drop the " " word separators before alignment
        label_phones = [phone for phone in org_label_phones if phone != " "]
        reference = " ".join(label_phones)   # native-speaker phones as a space-joined string
        hypothesis = " ".join(pred_phones)   # L2-speaker phones as a space-joined string
        # jiwer treats each phone as a "word", so its WER over phones is our PER
        res = jiwer.process_words(reference, hypothesis)
        per = res.wer
        # print(jiwer.visualize_alignment(res))

        # --- phoneme-level alignment ---
        # Build three parallel, visually aligned arrays:
        #   ref        expected phones ("*" padding where the speaker inserted extras)
        #   hyp        produced phones ("*" padding where the speaker dropped phones)
        #   error_bool per-phone marker: "-"*len = correct, trailing "s"/"d"/"a" =
        #              substitution / deletion / addition (insertion)
        alignments = res.alignments
        error_bool = []
        ref, hyp = [], []
        for alignment_chunk in alignments[0]:
            alignment_type = alignment_chunk.type
            ref_start_idx = alignment_chunk.ref_start_idx
            ref_end_idx = alignment_chunk.ref_end_idx
            hyp_start_idx = alignment_chunk.hyp_start_idx
            hyp_end_idx = alignment_chunk.hyp_end_idx
            if alignment_type != "equal":
                if alignment_type == "insert":
                    # speaker produced extra phones: pad ref with "*" and mark "a" (addition)
                    for i in range(hyp_start_idx, hyp_end_idx):
                        ref.append("*" * len(pred_phones[i]))
                        space_padding = " " * (len(pred_phones[i]) - 1)
                        error_bool.append(space_padding + "a")
                    hyp.extend(pred_phones[hyp_start_idx:hyp_end_idx])
                elif alignment_type == "delete":
                    # speaker dropped phones: pad hyp with "*" and mark "d"
                    ref.extend(label_phones[ref_start_idx:ref_end_idx])
                    for i in range(ref_start_idx, ref_end_idx):
                        hyp.append("*" * len(label_phones[i]))
                        space_padding = " " * (len(label_phones[i]) - 1)
                        error_bool.append(space_padding + alignment_type[0])
                else:
                    # substitution: pad the shorter phone string so columns line up, mark "s"
                    for i in range(ref_end_idx - ref_start_idx):
                        correct_phone = label_phones[ref_start_idx + i]
                        pred_phone = pred_phones[hyp_start_idx + i]
                        if len(correct_phone) > len(pred_phone):
                            space_padding = " " * (len(correct_phone) - len(pred_phone))
                            ref.append(correct_phone)
                            hyp.append(space_padding + pred_phone)
                            error_bool.append(" " * (len(correct_phone) - 1) + alignment_type[0])
                        else:
                            space_padding = " " * (len(pred_phone) - len(correct_phone))
                            ref.append(space_padding + correct_phone)
                            hyp.append(pred_phone)
                            error_bool.append(" " * (len(pred_phone) - 1) + alignment_type[0])
            else:
                # correct phones: copy through and mark with dashes
                ref.extend(label_phones[ref_start_idx:ref_end_idx])
                hyp.extend(pred_phones[hyp_start_idx:hyp_end_idx])
                # ref or hyp does not matter
                for i in range(ref_start_idx, ref_end_idx):
                    space_padding = "-" * (len(label_phones[i]))
                    error_bool.append(space_padding)

        # --- insert word delimiters to show user phoneme sections by word ---
        # Walk org_label_phones; each " " separator becomes a "|" inserted just after the
        # position of the previous word's last phone in the aligned arrays.
        # NOTE(review): the scan matches phones by value from the current position, so
        # delimiter placement may drift when consecutive words share a repeated phone —
        # verify with a repeated-phone utterance.
        delimiter_idx = 0
        for phone in org_label_phones:
            if phone == " ":
                hyp.insert(delimiter_idx + 1, "|")
                ref.insert(delimiter_idx + 1, "|")
                error_bool.insert(delimiter_idx + 1, "|")
                continue
            # advance to the aligned position of this expected phone ("*" padding never matches)
            while delimiter_idx < len(ref) and ref[delimiter_idx].strip() != phone:
                delimiter_idx += 1
        # word ends
        ref.append("|")
        hyp.append("|")

        # --- mispronounced words ---
        # a word is wrong if the fraction of erroneous phonemes in it exceeds the threshold
        words = text.split(" ")
        word_error_bool = self.get_mispronounced_words(error_bool, phoneme_error_threshold)
        wer = sum(word_error_bool) / len(words)
        raw_info = {"ref": ref, "hyp": hyp, "per": per, "phoneme_errors": error_bool, "wer": wer, "words": words, "word_errors": word_error_bool}
        return raw_info

    def get_mispronounced_words(self, phoneme_error_bool, phoneme_error_threshold):
        """Map phoneme-level errors back to word-level booleans.

        :param phoneme_error_bool: per-phone markers (see get_mispronounciation_output),
            with "|" separating words
        :type phoneme_error_bool: array
        :param phoneme_error_threshold: error fraction above which a word counts as wrong
        :type phoneme_error_threshold: float
        :return: one True/False per word (True = mispronounced)
        :rtype: array
        """
        word_error_bool = []
        # FIX: append the trailing "|" to a copy, not the caller's list — previously the
        # stray "|" leaked into raw_info["phoneme_errors"]
        word_phones = self.split_lst_by_delim(phoneme_error_bool + ["|"], "|")
        # wrong only if percentage of phones that are wrong > phoneme error threshold
        for phones in word_phones:
            # count substitution/deletion/addition markers
            error_count = sum(1 for phone in phones if phone in ("s", "d", "a"))
            # FIX: guard empty segments (adjacent delimiters) against ZeroDivisionError;
            # a word with no phones cannot be an error
            if phones and error_count / len(phones) > phoneme_error_threshold:
                word_error_bool.append(True)
            else:
                word_error_bool.append(False)
        return word_error_bool

    def split_lst_by_delim(self, lst, delimiter):
        """Split *lst* into sublists at each *delimiter*, stripping whitespace from items.

        Items after the final delimiter are discarded, so callers must ensure the list
        ends with a delimiter to keep the last segment.
        """
        temp = []
        res = []
        for item in lst:
            if item != delimiter:
                temp.append(item.strip())
            else:
                res.append(temp)
                temp = []
        return res