import torch
import jiwer

class MispronounciationDetector:
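  """Detect mispronunciations in L2 (learner) speech.

  Compares the phoneme sequence recognised from the learner's audio against
  the canonical phoneme sequence produced by a grapheme-to-phoneme (G2P)
  model, and reports phoneme- and word-level errors.
  """
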
  def __init__(self, l2_phoneme_recogniser, l2_phoneme_recogniser_processor, g2p, device):
    self.l2_phoneme_recogniser = l2_phoneme_recogniser
    self.l2_phoneme_recogniser_processor = l2_phoneme_recogniser_processor
    self.g2p = g2p
    self.device = device

  def detect(self, audio, text):
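    """Return phoneme- and word-level mispronunciation info for one utterance."""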
    l2_phones = self.get_l2_phoneme_sequence(audio)
    native_speaker_phones = self.get_native_speaker_phoneme_sequence(text)
    raw_info = self.get_mispronounciation_output(text, l2_phones, native_speaker_phones)
    return raw_info

  def get_l2_phoneme_sequence(self, audio):
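    """Run the CTC phoneme recogniser on 16 kHz audio and return the predicted phoneme sequence."""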
    input_dict = self.l2_phoneme_recogniser_processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
    logits = self.l2_phoneme_recogniser(input_dict.input_values.to(self.device)).logits
    pred_ids = torch.argmax(logits, dim=-1)[0]
    pred_phones = [phoneme for phoneme in self.l2_phoneme_recogniser_processor.batch_decode(pred_ids) if phoneme != ""]
    return pred_phones

  def get_native_speaker_phoneme_sequence(self, text):
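    """Return the canonical phoneme sequence for the text (the G2P output is expected to mark word boundaries with " ")."""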
    phonemes = self.g2p(text)
    return phonemes

  def get_mispronounciation_output(self, text, pred_phones, org_label_phones):
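    """Align predicted and canonical phonemes and summarise the errors.

    Returns a dict with column-aligned "ref"/"hyp" phoneme lists, a per-phoneme
    error track ("s" substitution, "d" deletion, "a" addition, "-" correct,
    "|" word boundary), the phoneme error rate ("per"), the words, the per-word
    error flags and the word error rate ("wer").
    """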
    # phoneme error rate (PER): align the two phoneme sequences with jiwer,
    # treating each phoneme as a "word"
    label_phones = [phone for phone in org_label_phones if phone != " "]
    reference = " ".join(label_phones) # canonical phones from the G2P model
    hypothesis = " ".join(pred_phones) # phones recognised from the L2 speaker
    res = jiwer.process_words(reference, hypothesis)
    per = res.wer # jiwer's WER over phoneme tokens is the PER
    # print(jiwer.visualize_alignment(res))

    # build column-aligned reference/hypothesis phone lists plus an error track
    # ("s" substitution, "d" deletion, "a" addition, "-" correct)
    alignments = res.alignments
    error_bool = []
    ref, hyp = [], []
    for alignment_chunk in alignments[0]:
      alignment_type = alignment_chunk.type
      ref_start_idx = alignment_chunk.ref_start_idx
      ref_end_idx = alignment_chunk.ref_end_idx
      hyp_start_idx = alignment_chunk.hyp_start_idx
      hyp_end_idx = alignment_chunk.hyp_end_idx
      if alignment_type != "equal":
        if alignment_type == "insert":
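          # extra phones produced by the learner: pad the reference with "*" and mark "a"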
          for i in range(hyp_start_idx, hyp_end_idx):
            ref.append("*" * len(pred_phones[i]))
            space_padding = " " * (len(pred_phones[i])-1)
            error_bool.append(space_padding + "a")
          hyp.extend(pred_phones[hyp_start_idx:hyp_end_idx])
        elif alignment_type == "delete":
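          # phones the learner left out: pad the hypothesis with "*" and mark "d"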
          ref.extend(label_phones[ref_start_idx:ref_end_idx])
          for i in range(ref_start_idx, ref_end_idx):
            hyp.append("*" * len(label_phones[i]))
            space_padding = " " * (len(label_phones[i])-1)
            error_bool.append(space_padding + alignment_type[0])
        else:
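          # substitutions: pad the shorter phone so the columns line up and mark "s"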
          for i in range(ref_end_idx - ref_start_idx):
            correct_phone = label_phones[ref_start_idx+i]
            pred_phone = pred_phones[hyp_start_idx+i]
            if len(correct_phone) > len(pred_phone):
              space_padding = " " * (len(correct_phone) - len(pred_phone))
              ref.append(correct_phone)
              hyp.append(space_padding + pred_phone)
              error_bool.append(" " * (len(correct_phone)-1) + alignment_type[0])
            else:
              space_padding = " " * (len(pred_phone) - len(correct_phone))
              ref.append(space_padding + correct_phone)
              hyp.append(pred_phone)
              error_bool.append(" " * (len(pred_phone)-1) + alignment_type[0])
      else:
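        # correctly pronounced phones: mark with "-" of the same width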
        ref.extend(label_phones[ref_start_idx:ref_end_idx])
        hyp.extend(pred_phones[hyp_start_idx:hyp_end_idx])
        # ref or hyp does not matter
        for i in range(ref_start_idx, ref_end_idx):
          space_padding = "-" * (len(label_phones[i]))
          error_bool.append(space_padding)

    # re-insert the word boundaries ("|") that the G2P output marks with spaces
    delimiter_idx = 0
    for phone in org_label_phones:
      if phone == " ":
        hyp.insert(delimiter_idx+1, "|")
        ref.insert(delimiter_idx+1, "|")
        error_bool.insert(delimiter_idx+1, "|")
        continue
      while delimiter_idx < len(ref) and ref[delimiter_idx].strip() != phone:
        delimiter_idx += 1
    # word ends
    ref.append("|")
    hyp.append("|")

    # map phoneme-level errors back to words to get the word error rate
    words = text.split(" ")
    word_error_bool = self.get_mispronounced_words(error_bool)
    wer = sum(word_error_bool) / len(words)

    raw_info = {
      "ref": ref,
      "hyp": hyp,
      "per": per,
      "phoneme_errors": error_bool,
      "wer": wer,
      "words": words,
      "word_errors": word_error_bool,
    }

    return raw_info


  def get_mispronounced_words(self, phoneme_error_bool):
    # map mispronounced phonemes back to the words that contain them to get the WER
    word_error_bool = []
    phoneme_error_bool.append("|") # make sure the final word ends with a delimiter
    word_phones = self.split_lst_by_delim(phoneme_error_bool, "|")
    for phones in word_phones:
      if "s" in phones or "d" in phones or "a" in phones:
        word_error_bool.append(True)
      else:
        word_error_bool.append(False)
    return word_error_bool


  def split_lst_by_delim(self, lst, delimiter):
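    """Split a flat list into sub-lists on the delimiter.

    Items after the last delimiter are dropped, so callers append a trailing
    delimiter first (see get_mispronounced_words).
    """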
    temp = []
    res = []
    for item in lst:
      if item != delimiter:
        temp.append(item.strip())
      else:
        res.append(temp)
        temp = []
    return res
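

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). The checkpoint
# path, the audio file name and the g2p_en / librosa / transformers imports are
# assumptions: any CTC phoneme recogniser and G2P model will do, as long as
# both use the same phoneme inventory and the G2P output separates words with
# " " (g2p_en.G2p does).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
  import librosa
  from g2p_en import G2p
  from transformers import AutoProcessor, Wav2Vec2ForCTC

  device = "cuda" if torch.cuda.is_available() else "cpu"

  # hypothetical checkpoint: a wav2vec2 CTC model fine-tuned to predict phonemes
  checkpoint = "path/to/phoneme-recogniser-checkpoint"
  processor = AutoProcessor.from_pretrained(checkpoint)
  model = Wav2Vec2ForCTC.from_pretrained(checkpoint).to(device)

  detector = MispronounciationDetector(model, processor, G2p(), device)

  # 16 kHz mono audio, as expected by get_l2_phoneme_sequence
  audio, _ = librosa.load("learner_recording.wav", sr=16000)  # hypothetical file
  raw_info = detector.detect(audio, "the quick brown fox")

  print("PER:", raw_info["per"])
  print("WER:", raw_info["wer"])
  print("mispronounced words:",
        [w for w, err in zip(raw_info["words"], raw_info["word_errors"]) if err])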