import time
from string import punctuation

import epitran
import numpy as np
import torch

import ModelInterfaces as mi
import RuleBasedModels
import WordMatching as wm
import WordMetrics
import models as mo
from constants import app_logger, MODEL_NAME_DEFAULT, sample_rate_resample

def preprocessAudioStandalone(audio: torch.Tensor) -> torch.Tensor:
    """
    Preprocess the audio by normalizing it.

    Args:
        audio (torch.Tensor): The input audio tensor.

    Returns:
        torch.Tensor: The normalized audio tensor.
    """
    audio = audio - torch.mean(audio)
    audio = audio / torch.max(torch.abs(audio))
    return audio
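
# Illustrative sketch, not part of the original module: the tensor values below
# are assumptions chosen only to show the effect of the normalization above,
# which removes the DC offset and scales the waveform to a peak absolute value of 1.
def _example_preprocess_audio() -> torch.Tensor:
    waveform = torch.tensor([[0.5, 1.0, 1.5]])  # mean is 1.0, so it becomes [-0.5, 0.0, 0.5]
    return preprocessAudioStandalone(waveform)  # then scaled by 1 / 0.5 to [-1.0, 0.0, 1.0]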

class PronunciationTrainer:
    """
    A class to train and evaluate pronunciation accuracy using ASR and phoneme conversion models.
    """
    current_transcript: str
    current_ipa: str

    current_recorded_audio: torch.Tensor
    current_recorded_transcript: str
    current_recorded_word_locations: list
    current_recorded_intonations: torch.Tensor
    current_words_pronunciation_accuracy = []
    categories_thresholds = np.array([80, 60, 59])

    sampling_rate = sample_rate_resample

    def __init__(self, asr_model: mi.IASRModel, word_to_ipa_converter: mi.ITextToPhonemModel) -> None:
        """
        Initialize the PronunciationTrainer with ASR and phoneme conversion models.

        Args:
            asr_model (mi.IASRModel): The ASR model to use.
            word_to_ipa_converter (mi.ITextToPhonemModel): The phoneme conversion model to use.
        """
        self.asr_model = asr_model
        self.ipa_converter = word_to_ipa_converter

    def getTranscriptAndWordsLocations(self, audio_length_in_samples: int) -> tuple[str, list]:
        """
        Get the transcript and word locations from the ASR model.

        Args:
            audio_length_in_samples (int): The length of the audio in samples.

        Returns:
            tuple: A tuple containing the audio transcript and the word locations in samples.
        """
        audio_transcript = self.asr_model.getTranscript()
        word_locations_in_samples = self.asr_model.getWordLocations()

        fade_duration_in_samples = 0.05 * self.sampling_rate
        word_locations_in_samples = [
            (int(np.maximum(0, word['start_ts'] - fade_duration_in_samples)),
             int(np.minimum(audio_length_in_samples - 1, word['end_ts'] + fade_duration_in_samples)))
            for word in word_locations_in_samples]

        return audio_transcript, word_locations_in_samples
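
    # Example with assumed numbers: at a 16 kHz sampling rate the fade pad is
    # 0.05 * 16000 = 800 samples, so a word detected at samples (1600, 4800) is
    # widened to (800, 5600) and clipped to the valid audio range.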

    # def getWordsRelativeIntonation(self, Audio: torch.tensor, word_locations: list):
    #     intonations = torch.zeros((len(word_locations), 1))
    #     intonation_fade_samples = 0.3 * self.sampling_rate
    #     app_logger.info(f"intonations.shape: {intonations.shape}.")
    #     for word in range(len(word_locations)):
    #         intonation_start = int(np.maximum(
    #             0, word_locations[word][0] - intonation_fade_samples))
    #         intonation_end = int(np.minimum(
    #             Audio.shape[1] - 1, word_locations[word][1] + intonation_fade_samples))
    #         intonations[word] = torch.sqrt(torch.mean(
    #             Audio[0][intonation_start:intonation_end] ** 2))
    #
    #     intonations = intonations / torch.mean(intonations)
    #     return intonations

    ##################### ASR Functions ###########################

    def processAudioForGivenText(self, recordedAudio: torch.Tensor = None, real_text=None) -> dict:
        """
        Process the recorded audio and evaluate pronunciation accuracy.

        Args:
            recordedAudio (torch.Tensor, optional): The recorded audio tensor. Defaults to None.
            real_text (str, optional): The real text to compare against. Defaults to None.

        Returns:
            dict: A dictionary containing the evaluation results.
        """
        start = time.time()
        recording_transcript, recording_ipa, word_locations = self.getAudioTranscript(recordedAudio)
        time_transcript_audio = time.time() - start
        app_logger.info(f'Time for NN to transcript audio: {time_transcript_audio:.2f}.')

        start = time.time()
        real_and_transcribed_words, real_and_transcribed_words_ipa, mapped_words_indices = self.matchSampleAndRecordedWords(
            real_text, recording_transcript)
        time_matching_transcripts = time.time() - start
        app_logger.info(f'Time for matching transcripts: {time_matching_transcripts:.3f}.')

        start_time, end_time = self.getWordLocationsFromRecordInSeconds(
            word_locations, mapped_words_indices)

        pronunciation_accuracy, current_words_pronunciation_accuracy = self.getPronunciationAccuracy(
            real_and_transcribed_words)  # _ipa
        pronunciation_categories = self.getWordsPronunciationCategory(
            current_words_pronunciation_accuracy)

        result = {'recording_transcript': recording_transcript,
                  'real_and_transcribed_words': real_and_transcribed_words,
                  'recording_ipa': recording_ipa, 'start_time': start_time, 'end_time': end_time,
                  'real_and_transcribed_words_ipa': real_and_transcribed_words_ipa,
                  'pronunciation_accuracy': pronunciation_accuracy,
                  'pronunciation_categories': pronunciation_categories}
        return result

    def getAudioTranscript(self, recordedAudio: torch.Tensor = None) -> tuple[str, str, list]:
        """
        Get the transcript and IPA representation of the recorded audio.

        Args:
            recordedAudio (torch.Tensor, optional): The recorded audio tensor. Defaults to None.

        Returns:
            tuple: A tuple containing the transcript, the IPA representation, and the word locations.
        """
        current_recorded_audio = recordedAudio
        current_recorded_audio = self.preprocessAudio(current_recorded_audio)

        self.asr_model.processAudio(current_recorded_audio)

        current_recorded_transcript, current_recorded_word_locations = self.getTranscriptAndWordsLocations(
            current_recorded_audio.shape[1])
        current_recorded_ipa = self.ipa_converter.convertToPhonem(current_recorded_transcript)

        return current_recorded_transcript, current_recorded_ipa, current_recorded_word_locations

    def getWordLocationsFromRecordInSeconds(self, word_locations, mapped_words_indices) -> tuple[str, str]:
        """
        Get the start and end times of the matched words in the recorded audio, in seconds.

        Args:
            word_locations (list): The word locations in samples.
            mapped_words_indices (list): The indices of the mapped words.

        Returns:
            tuple: Two space-separated strings with the start times and the end times of the words in seconds.
        """
        app_logger.info(f"len_list: word_locations:{len(word_locations)}, mapped_words_indices:{len(mapped_words_indices)}, {len(word_locations) == len(mapped_words_indices)}...")
        start_time = []
        end_time = []
        for word_idx in range(len(mapped_words_indices)):
            start_time.append(float(word_locations[mapped_words_indices[word_idx]][0]) / self.sampling_rate)
            end_time.append(float(word_locations[mapped_words_indices[word_idx]][1]) / self.sampling_rate)
        return ' '.join([str(time) for time in start_time]), ' '.join([str(time) for time in end_time])
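
    # Example with assumed values: for two matched words located at samples
    # (800, 9600) and (11200, 20800) at a 16 kHz sampling rate, the method
    # returns the pair of strings ('0.05 0.7', '0.6 1.3').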

    ##################### END ASR Functions ###########################

    ##################### Evaluation Functions ###########################

    def matchSampleAndRecordedWords(self, real_text, recorded_transcript):
        """
        Match the real text with the recorded transcript and get the IPA representations.

        Args:
            real_text (str): The real text to compare against.
            recorded_transcript (str): The recorded transcript.

        Returns:
            tuple: A tuple containing the matched words, their IPA representations, and the mapped word indices.
        """
        words_estimated = recorded_transcript.split()
        try:
            words_real = real_text.split()
        except AttributeError:
            raise ValueError("Real text is None, but should be a string.")

        mapped_words, mapped_words_indices = wm.get_best_mapped_words(
            words_estimated, words_real)

        real_and_transcribed_words = []
        real_and_transcribed_words_ipa = []
        for word_idx in range(len(words_real)):
            if word_idx >= len(mapped_words):
                # Pad with a placeholder when the transcript has fewer words than the reference text.
                mapped_words.append('-')
            real_and_transcribed_words.append(
                (words_real[word_idx], mapped_words[word_idx]))
            real_and_transcribed_words_ipa.append(
                (self.ipa_converter.convertToPhonem(words_real[word_idx]),
                 self.ipa_converter.convertToPhonem(mapped_words[word_idx])))
        return real_and_transcribed_words, real_and_transcribed_words_ipa, mapped_words_indices
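
    # Illustration with assumed strings: for real_text "the cat" and a recorded
    # transcript "the bat", the matcher is expected to pair ("the", "the") and
    # ("cat", "bat"); the IPA pairs are then built from those matched words.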

    def getPronunciationAccuracy(self, real_and_transcribed_words_ipa) -> tuple[float, list]:
        """
        Calculate the pronunciation accuracy based on the IPA representations.

        Args:
            real_and_transcribed_words_ipa (list): A list of tuples containing the real and transcribed IPA representations.

        Returns:
            tuple: The rounded overall percentage of correctly pronounced phonemes and a list with the per-word accuracies.
        """
        total_mismatches = 0.
        number_of_phonemes = 0.
        current_words_pronunciation_accuracy = []
        for pair in real_and_transcribed_words_ipa:
            real_without_punctuation = self.removePunctuation(pair[0]).lower()
            number_of_word_mismatches = WordMetrics.edit_distance_python(
                real_without_punctuation, self.removePunctuation(pair[1]).lower())
            total_mismatches += number_of_word_mismatches
            number_of_phonemes_in_word = len(real_without_punctuation)
            number_of_phonemes += number_of_phonemes_in_word

            current_words_pronunciation_accuracy.append(float(
                number_of_phonemes_in_word - number_of_word_mismatches) / number_of_phonemes_in_word * 100)

        percentage_of_correct_pronunciations = (
            number_of_phonemes - total_mismatches) / number_of_phonemes * 100

        return np.round(percentage_of_correct_pronunciations), current_words_pronunciation_accuracy
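
    # Worked example with assumed IPA values: for the single pair ('kæt', 'kʌt')
    # the edit distance is 1 over 3 phonemes, so the per-word accuracy is
    # (3 - 1) / 3 * 100 ≈ 66.7 and the rounded overall accuracy is 67.0.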

    def removePunctuation(self, word: str) -> str:
        """
        Remove punctuation from a word.

        Args:
            word (str): The input word.

        Returns:
            str: The word without punctuation.
        """
        return ''.join([char for char in word if char not in punctuation])

    def getWordsPronunciationCategory(self, accuracies) -> list:
        """
        Get the pronunciation category for each word based on accuracy.

        Args:
            accuracies (list): A list of pronunciation accuracies.

        Returns:
            list: A list of pronunciation categories.
        """
        categories = []
        for accuracy in accuracies:
            categories.append(
                self.getPronunciationCategoryFromAccuracy(accuracy))
        return categories

    def getPronunciationCategoryFromAccuracy(self, accuracy) -> int:
        """
        Get the pronunciation category based on accuracy.

        Args:
            accuracy (float): The pronunciation accuracy.

        Returns:
            int: The pronunciation category.
        """
        return np.argmin(abs(self.categories_thresholds - accuracy))
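
    # Example with an assumed accuracy value: with thresholds [80, 60, 59], an
    # accuracy of 75 is closest to 80, so the returned category index is 0.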

    def preprocessAudio(self, audio: torch.Tensor) -> torch.Tensor:
        """
        Preprocess the audio by normalizing it.

        Args:
            audio (torch.Tensor): The input audio tensor.

        Returns:
            torch.Tensor: The normalized audio tensor.
        """
        return preprocessAudioStandalone(audio)

def getTrainer(language: str, model_name: str = MODEL_NAME_DEFAULT) -> PronunciationTrainer:
    """
    Get a PronunciationTrainer instance for the specified language and model.

    Args:
        language (str): The language of the model.
        model_name (str, optional): The name of the model. Defaults to MODEL_NAME_DEFAULT.

    Returns:
        PronunciationTrainer: An instance of PronunciationTrainer.
    """
    asr_model = mo.getASRModel(language, model_name=model_name)

    if language == 'de':
        phonem_converter = RuleBasedModels.EpitranPhonemConverter(epitran.Epitran('deu-Latn'))
    elif language == 'en':
        phonem_converter = RuleBasedModels.EngPhonemConverter()
    else:
        raise ValueError(f"Language '{language}' not implemented")

    trainer = PronunciationTrainer(asr_model, phonem_converter)

    return trainer
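
# Minimal usage sketch, an assumption rather than part of the original module:
# it builds an English trainer and scores one second of synthetic audio against
# a reference sentence. Real callers would pass recorded speech resampled to
# `sample_rate_resample` instead of random noise.
if __name__ == "__main__":
    example_trainer = getTrainer('en')
    placeholder_audio = torch.randn(1, sample_rate_resample)  # placeholder waveform, one second long
    example_result = example_trainer.processAudioForGivenText(placeholder_audio, 'hello world')
    app_logger.info(f"pronunciation accuracy: {example_result['pronunciation_accuracy']}")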