import torch
import jiwer


class MispronounciationDetector:
    """Compares an L2 speaker's recognised phonemes against the expected native speaker phonemes."""

    def __init__(self, l2_phoneme_recogniser, g2p, device):
        self.phoneme_asr_model = l2_phoneme_recogniser  # PhonemeASRModel instance
        self.g2p = g2p  # grapheme-to-phoneme converter, callable on text
        self.device = device

    def detect(self, audio, text):
        # phonemes actually produced by the L2 speaker, predicted from the audio
        l2_phones = self.phoneme_asr_model.get_l2_phoneme_sequence(audio)
        # phonemes a native speaker would be expected to produce for the same text
        native_speaker_phones = self.get_native_speaker_phoneme_sequence(text)
        # map the G2P phoneme labels onto the phone set used by the ASR model
        standardised_native_speaker_phones = self.phoneme_asr_model.standardise_g2p_phoneme_sequence(native_speaker_phones)
        raw_info = self.get_mispronounciation_output(text, l2_phones, standardised_native_speaker_phones)
        return raw_info

    def get_native_speaker_phoneme_sequence(self, text):
        # G2P output: a list of phones, with " " entries separating words
        phonemes = self.g2p(text)
        return phonemes

    def get_mispronounciation_output(self, text, pred_phones, org_label_phones):
        """
        Aligns the phones predicted for the L2 speaker with the expected native speaker phones to locate the errors.

        :param text: original words read by the user
        :type text: string
        :param pred_phones: phonemes predicted for the L2 speaker by the ASR model
        :type pred_phones: list
        :param org_label_phones: correct, native speaker phonemes from G2P, where the phonemes of each word are separated by " "
        :type org_label_phones: list
        :return: dictionary containing mispronunciation information such as PER, WER and error boolean arrays at phoneme/word level
        :rtype: dict
        """
        # compute the phoneme error rate (PER): each phone is treated as a "word",
        # so jiwer's word-level alignment and WER become phone-level ones
        label_phones = [phone for phone in org_label_phones if phone != " "]
        reference = " ".join(label_phones)  # expected native speaker phones
        hypothesis = " ".join(pred_phones)  # phones predicted for the L2 speaker
        res = jiwer.process_words(reference, hypothesis)
        per = res.wer
        # print(jiwer.visualize_alignment(res))

        # walk the phoneme alignments to build the aligned error output
        alignments = res.alignments
        error_bool = []
        ref, hyp = [], []
        for alignment_chunk in alignments[0]:
            alignment_type = alignment_chunk.type
            ref_start_idx = alignment_chunk.ref_start_idx
            ref_end_idx = alignment_chunk.ref_end_idx
            hyp_start_idx = alignment_chunk.hyp_start_idx
            hyp_end_idx = alignment_chunk.hyp_end_idx

            if alignment_type != "equal":
                if alignment_type == "insert":
                    # phones the L2 speaker added: pad the reference side with "*"
                    for i in range(hyp_start_idx, hyp_end_idx):
                        ref.append("*" * len(pred_phones[i]))
                        space_padding = " " * (len(pred_phones[i]) - 1)
                        error_bool.append(space_padding + "a")
                    hyp.extend(pred_phones[hyp_start_idx:hyp_end_idx])
                elif alignment_type == "delete":
                    # phones the L2 speaker dropped: pad the hypothesis side with "*"
                    ref.extend(label_phones[ref_start_idx:ref_end_idx])
                    for i in range(ref_start_idx, ref_end_idx):
                        hyp.append("*" * len(label_phones[i]))
                        space_padding = " " * (len(label_phones[i]) - 1)
                        error_bool.append(space_padding + alignment_type[0])
                else:
                    # substitutions: pad the shorter phone so both columns line up
                    for i in range(ref_end_idx - ref_start_idx):
                        correct_phone = label_phones[ref_start_idx + i]
                        pred_phone = pred_phones[hyp_start_idx + i]
                        if len(correct_phone) > len(pred_phone):
                            space_padding = " " * (len(correct_phone) - len(pred_phone))
                            ref.append(correct_phone)
                            hyp.append(space_padding + pred_phone)
                            error_bool.append(" " * (len(correct_phone) - 1) + alignment_type[0])
                        else:
                            space_padding = " " * (len(pred_phone) - len(correct_phone))
                            ref.append(space_padding + correct_phone)
                            hyp.append(pred_phone)
                            error_bool.append(" " * (len(pred_phone) - 1) + alignment_type[0])
            else:
                # correct phones: copy them over and mark them with "-"
                ref.extend(label_phones[ref_start_idx:ref_end_idx])
                hyp.extend(pred_phones[hyp_start_idx:hyp_end_idx])
                # ref or hyp does not matter here, both ranges hold the same phones
                for i in range(ref_start_idx, ref_end_idx):
                    space_padding = "-" * (len(label_phones[i]))
                    error_bool.append(space_padding)
        # insert word delimiters ("|") so the user can see the phoneme sections by word
        delimiter_idx = 0
        for phone in org_label_phones:
            if phone == " ":
                hyp.insert(delimiter_idx + 1, "|")
                ref.insert(delimiter_idx + 1, "|")
                error_bool.insert(delimiter_idx + 1, "|")
                continue
            # advance to the aligned position of this phone in the reference
            while delimiter_idx < len(ref) and ref[delimiter_idx].strip() != phone:
                delimiter_idx += 1
        # word ends
        ref.append("|")
        hyp.append("|")
        # flag mispronounced words based on whether any phoneme errors occur within that word's phonemes
        words = text.split(" ")
        word_error_bool = self.get_mispronounced_words(error_bool)
        wer = sum(word_error_bool) / len(words)
        raw_info = {
            "ref": ref,
            "hyp": hyp,
            "per": per,
            "phoneme_errors": error_bool,
            "wer": wer,
            "words": words,
            "word_errors": word_error_bool,
        }
        return raw_info
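
    # Illustrative shape of the raw_info dictionary (the phone values are
    # hypothetical and only demonstrate the padding/marker conventions, they are
    # not real model output). For text "hello", expected phones HH AH L OW and
    # predicted phones HH AA L OW (one substitution):
    #   {
    #       "ref":            ["HH", "AH", "L", "OW", "|"],
    #       "hyp":            ["HH", "AA", "L", "OW", "|"],
    #       "phoneme_errors": ["--", " s", "-", "--", "|"],
    #       "per":            0.25,
    #       "words":          ["hello"],
    #       "word_errors":    [True],
    #       "wer":            1.0,
    #   }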

    def get_mispronounced_words(self, phoneme_error_bool):
        # map mispronounced phones back to the words they belong to, giving a word-level error flag per word
        word_error_bool = []
        phoneme_error_bool.append("|")  # ensure the final word is terminated by a delimiter
        word_phones = self.split_lst_by_delim(phoneme_error_bool, "|")
        for phones in word_phones:
            # a word is mispronounced if any of its phones was substituted ("s"), deleted ("d") or added ("a")
            if "s" in phones or "d" in phones or "a" in phones:
                word_error_bool.append(True)
            else:
                word_error_bool.append(False)
        return word_error_bool

    def split_lst_by_delim(self, lst, delimiter):
        # split a flat list into sublists, using the given delimiter element as the boundary
        temp = []
        res = []
        for item in lst:
            if item != delimiter:
                temp.append(item.strip())
            else:
                res.append(temp)
                temp = []
        return res
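

# Minimal usage sketch (assumptions, not part of the original class): the G2P
# backend (g2p_en) and the DummyPhonemeASRModel below stand in for whatever models
# the rest of this Space actually loads; only MispronounciationDetector above is
# defined in this file.
if __name__ == "__main__":
    from g2p_en import G2p  # assumed G2P backend; any callable mapping text -> phone list works

    class DummyPhonemeASRModel:
        """Stand-in for the real PhonemeASRModel wrapper."""

        def get_l2_phoneme_sequence(self, audio):
            # a real model would transcribe the audio; return fixed phones for the demo
            return ["HH", "AA", "L", "OW"]

        def standardise_g2p_phoneme_sequence(self, phonemes):
            # a real implementation maps G2P labels onto the ASR model's phone set;
            # here we only strip stress digits and keep the word separators
            return [p.rstrip("012") if p != " " else p for p in phonemes]

    detector = MispronounciationDetector(DummyPhonemeASRModel(), G2p(), device="cpu")
    raw_info = detector.detect(audio=None, text="hello")
    print("PER:", raw_info["per"], "WER:", raw_info["wer"])
    print("ref:", raw_info["ref"])
    print("hyp:", raw_info["hyp"])
    print("phoneme_errors:", raw_info["phoneme_errors"])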