Spaces:
Runtime error
Runtime error
File size: 5,167 Bytes
6d3dc99 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
from pandas.core.construction import T
import torch
import jiwer
class MispronounciationDetector:
    """Detect mispronounced phonemes and words in an L2 speaker's utterance.

    The phonemes recognised from the speaker's audio are compared with the
    reference phonemes that a grapheme-to-phoneme (G2P) converter produces
    for the prompt text. The two sequences are aligned with jiwer, and both
    phoneme-level and word-level errors are reported.
    """

    def __init__(self, l2_phoneme_recogniser, l2_phoneme_recogniser_processor, g2p, device):
        """Store the components used for detection.

        Args:
            l2_phoneme_recogniser: acoustic model called on ``input_values``;
                its output must expose ``.logits``.
            l2_phoneme_recogniser_processor: callable that turns raw audio
                into model inputs and decodes ids via ``batch_decode``.
            g2p: callable mapping text to a phoneme sequence in which ``" "``
                marks a word boundary.
            device: target passed to ``.to(...)`` for the model inputs.
        """
        self.l2_phoneme_recogniser = l2_phoneme_recogniser
        self.l2_phoneme_recogniser_processor = l2_phoneme_recogniser_processor
        self.g2p = g2p
        self.device = device

    def detect(self, audio, text):
        """Run the full pipeline on one utterance.

        Args:
            audio: waveform of a single utterance; presumably sampled at
                16 kHz, the rate passed to the processor.
            text: the prompt the speaker read.

        Returns:
            The dict produced by ``get_mispronounciation_output``.
        """
        l2_phones = self.get_l2_phoneme_sequence(audio)
        native_speaker_phones = self.get_native_speaker_phoneme_sequence(text)
        return self.get_mispronounciation_output(text, l2_phones, native_speaker_phones)

    def get_l2_phoneme_sequence(self, audio):
        """Recognise the phonemes the speaker actually produced.

        Returns:
            Decoded phoneme strings for the first batch item, with empty
            decodings removed.
        """
        input_dict = self.l2_phoneme_recogniser_processor(
            audio, sampling_rate=16000, return_tensors="pt", padding=True
        )
        # Inference only: do not build an autograd graph.
        with torch.no_grad():
            logits = self.l2_phoneme_recogniser(input_dict.input_values.to(self.device)).logits
        pred_ids = torch.argmax(logits, dim=-1)[0]
        # NOTE(review): ids are decoded one frame at a time, so whether
        # repeated CTC frames are collapsed depends on the processor; confirm.
        decoded = self.l2_phoneme_recogniser_processor.batch_decode(pred_ids)
        return [phoneme for phoneme in decoded if phoneme != ""]

    def get_native_speaker_phoneme_sequence(self, text):
        """Return the reference phoneme sequence for ``text`` from G2P."""
        return self.g2p(text)

    def get_mispronounciation_output(self, text, pred_phones, org_label_phones):
        """Align recognised phonemes with reference phonemes and score them.

        Args:
            text: prompt text; split on single spaces to obtain ``words``.
            pred_phones: phonemes recognised from the speaker's audio.
            org_label_phones: reference phonemes, where ``" "`` separates words.

        Returns:
            dict with keys:
              ``ref`` / ``hyp``: column-aligned reference / recognised
                  phonemes (``"*"`` fills gaps, ``"|"`` ends each word);
              ``per``: phoneme error rate as computed by jiwer;
              ``phoneme_errors``: per-column markers whose last character is
                  ``-`` (correct), ``s`` (substituted), ``d`` (deleted) or
                  ``a`` (added/inserted), plus ``"|"`` word ends;
              ``wer``: erroneous words divided by the number of words in ``text``;
              ``words``: ``text`` split on single spaces;
              ``word_errors``: one bool per word, True if mispronounced.
        """
        label_phones = [phone for phone in org_label_phones if phone != " "]
        # jiwer aligns space-separated tokens; with one phoneme per token its
        # "WER" is the phoneme error rate.
        res = jiwer.process_words(" ".join(label_phones), " ".join(pred_phones))
        per = res.wer

        ref, hyp, error_bool = self._build_alignment_rows(res.alignments[0], label_phones, pred_phones)
        self._insert_word_delimiters(org_label_phones, ref, hyp, error_bool)
        # Close the final word in the displayed rows.
        ref.append("|")
        hyp.append("|")

        words = text.split(" ")
        word_error_bool = self.get_mispronounced_words(error_bool)
        # get_mispronounced_words does not modify its argument, so close the
        # final word here too to keep phoneme_errors aligned with ref/hyp.
        error_bool.append("|")
        wer = sum(word_error_bool) / len(words)
        return {
            "ref": ref,
            "hyp": hyp,
            "per": per,
            "phoneme_errors": error_bool,
            "wer": wer,
            "words": words,
            "word_errors": word_error_bool,
        }

    @staticmethod
    def _build_alignment_rows(alignment_chunks, label_phones, pred_phones):
        """Convert alignment chunks into three equal-length display rows.

        Each column holds a reference entry, a recognised entry padded to the
        same width, and an error marker of that width ending in the error code.

        Args:
            alignment_chunks: objects with ``type``, ``ref_start_idx``,
                ``ref_end_idx``, ``hyp_start_idx`` and ``hyp_end_idx``
                (jiwer alignment chunks).
            label_phones: reference phonemes, without word separators.
            pred_phones: recognised phonemes.

        Returns:
            Tuple ``(ref, hyp, error_bool)`` of lists with equal length.
        """
        ref, hyp, error_bool = [], [], []
        for chunk in alignment_chunks:
            ref_span = range(chunk.ref_start_idx, chunk.ref_end_idx)
            hyp_span = range(chunk.hyp_start_idx, chunk.hyp_end_idx)
            if chunk.type == "equal":
                ref.extend(label_phones[chunk.ref_start_idx:chunk.ref_end_idx])
                hyp.extend(pred_phones[chunk.hyp_start_idx:chunk.hyp_end_idx])
                # Reference and hypothesis are identical here, so either width works.
                error_bool.extend("-" * len(label_phones[i]) for i in ref_span)
            elif chunk.type == "insert":
                for i in hyp_span:
                    ref.append("*" * len(pred_phones[i]))
                    hyp.append(pred_phones[i])
                    error_bool.append(" " * (len(pred_phones[i]) - 1) + "a")
            elif chunk.type == "delete":
                for i in ref_span:
                    ref.append(label_phones[i])
                    hyp.append("*" * len(label_phones[i]))
                    error_bool.append(" " * (len(label_phones[i]) - 1) + "d")
            else:
                # Substitution: both spans have the same length; right-align
                # the shorter phoneme so the columns line up.
                for i, j in zip(ref_span, hyp_span):
                    correct_phone = label_phones[i]
                    pred_phone = pred_phones[j]
                    width = max(len(correct_phone), len(pred_phone))
                    ref.append(correct_phone.rjust(width))
                    hyp.append(pred_phone.rjust(width))
                    error_bool.append(" " * (width - 1) + chunk.type[0])
        return ref, hyp, error_bool

    @staticmethod
    def _insert_word_delimiters(org_label_phones, ref, hyp, error_bool):
        """Insert a ``"|"`` column for every ``" "`` in ``org_label_phones``.

        The delimiter is placed right after the reference column that ends
        the word. Any inserted (``"*"``) columns that follow are therefore
        attributed to the next word. ``ref``, ``hyp`` and ``error_bool`` are
        modified in place and stay column-aligned.
        """
        cursor = 0  # first column not yet matched to a reference phoneme
        for phone in org_label_phones:
            if phone == " ":
                for row in (hyp, ref, error_bool):
                    row.insert(cursor, "|")
                cursor += 1
                continue
            while cursor < len(ref) and ref[cursor].strip() != phone:
                cursor += 1
            # Step past the matched column, so a repeated phoneme such as
            # "a a" is matched to its own column rather than the previous one.
            cursor += 1

    def get_mispronounced_words(self, phoneme_error_bool):
        """Map per-column phoneme error markers back to words.

        Args:
            phoneme_error_bool: per-column markers in which ``"|"`` separates
                words. The list is not modified.

        Returns:
            One bool per ``"|"``-separated group: True if the group contains
            an ``s``, ``d`` or ``a`` error.
        """
        # Append a closing delimiter on a copy, so the last word is kept and
        # the caller's list is left untouched.
        word_phones = self.split_lst_by_delim(list(phoneme_error_bool) + ["|"], "|")
        return [any(code in phones for code in ("s", "d", "a")) for phones in word_phones]

    def split_lst_by_delim(self, lst, delimiter):
        """Split ``lst`` into groups, each terminated by ``delimiter``.

        Items are whitespace-stripped. Items after the last delimiter form an
        unterminated group and are not returned.
        """
        groups, current = [], []
        for item in lst:
            if item == delimiter:
                groups.append(current)
                current = []
            else:
                current.append(item.strip())
        return groups