import json
import re
from sys import platform

import pyctcdecode
import torch
from transformers import (Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor,
                          Wav2Vec2ForCTC, Wav2Vec2Processor,
                          Wav2Vec2ProcessorWithLM)


class PhonemeASRModel:
    def get_l2_phoneme_sequence(self, audio):
        """
        :param audio: audio sampled at a 16 kHz sampling rate, loaded with torchaudio
        :type audio: array
        :return: predicted phonemes for the L2 speaker
        :rtype: array
        """
        pass

    def standardise_g2p_phoneme_sequence(self, phones):
        """
        To facilitate mispronunciation detection
        :param phones: native speaker phones predicted by the G2P model
        :type phones: array
        :return: standardised native speaker phoneme sequence that aligns with the model's phoneme classes
        :rtype: array
        """
        pass

    def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
        """
        To facilitate testing
        :param phones: native speaker phones as annotated in L2-ARCTIC
        :type phones: array
        :return: standardised native speaker phoneme sequence that aligns with the model's phoneme classes
        :rtype: array
        """
        pass


class Wav2Vec2PhonemeASRModel(PhonemeASRModel):
    """
    Uses greedy decoding
    """

    def __init__(self, model_path, processor_path):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = Wav2Vec2ForCTC.from_pretrained(model_path).to(self.device)
        self.processor = Wav2Vec2Processor.from_pretrained(processor_path)

    def get_l2_phoneme_sequence(self, audio):
        input_dict = self.processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
        logits = self.model(input_dict.input_values.to(self.device)).logits
        pred_ids = torch.argmax(logits, dim=-1)[0]
        # CTC greedy decoding: collapse repeated frame-level predictions before
        # dropping blanks, otherwise a phoneme spanning several frames is duplicated
        pred_ids = torch.unique_consecutive(pred_ids)
        pred_phones = [phoneme for phoneme in self.processor.batch_decode(pred_ids) if phoneme != ""]
        return pred_phones

    def standardise_g2p_phoneme_sequence(self, phones):
        return phones

    def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
        # strip lexical stress markers, e.g. "AH0" -> "AH"
        return [re.sub(r'\d', "", phone_str) for phone_str in phones]


# TODO: debug on Linux, since KenLM is not supported on Windows
class Wav2Vec2OptimisedPhonemeASRModel(PhonemeASRModel):
    """
    Uses beam search and an LM for decoding
    """

    def __init__(self, model_path, vocab_json_path, kenlm_model_path):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        with open(vocab_json_path) as f:
            vocab_dict = json.load(f)
        tokenizer = Wav2Vec2CTCTokenizer(vocab_json_path, unk_token="[UNK]", pad_token="[PAD]",
                                         word_delimiter_token="|")
        feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0,
                                                     do_normalize=True, return_attention_mask=False)
        labels = list(vocab_dict.keys())
        # plain beam search
        decoder = pyctcdecode.build_ctcdecoder(labels)
        if platform in ("linux", "linux2") and kenlm_model_path:
            # beam search + KenLM language model
            decoder = pyctcdecode.build_ctcdecoder(labels, kenlm_model_path=kenlm_model_path)
        self.model = Wav2Vec2ForCTC.from_pretrained(model_path).to(self.device)
        self.processor = Wav2Vec2ProcessorWithLM(feature_extractor=feature_extractor, tokenizer=tokenizer,
                                                 decoder=decoder)

    def get_l2_phoneme_sequence(self, audio):
        input_dict = self.processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
        logits = self.model(input_dict.input_values.to(self.device)).logits.cpu().detach()
        # the decoder works on per-frame probabilities, so normalise the logits first
        normalised_logits = torch.nn.functional.softmax(logits, dim=-1).numpy()[0]
        output = self.processor.decode(normalised_logits)
        pred_phones = output.text.split(" ")
        return pred_phones

    def standardise_g2p_phoneme_sequence(self, phones):
        return phones

    def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
        # strip lexical stress markers, e.g. "AH0" -> "AH"
        return [re.sub(r'\d', "", phone_str) for phone_str in phones]
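

if __name__ == "__main__":
    # Minimal usage sketch for the greedy model. The checkpoint paths and the
    # wav filename below are placeholders (assumptions, not artefacts shipped
    # with this repo); swap in your own fine-tuned phoneme checkpoint and a
    # recording of an L2 speaker.
    import torchaudio

    waveform, sample_rate = torchaudio.load("l2_speaker_sample.wav")
    if sample_rate != 16000:
        # the models above expect 16 kHz input
        waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)

    asr_model = Wav2Vec2PhonemeASRModel("path/to/wav2vec2-phoneme-checkpoint",
                                        "path/to/wav2vec2-phoneme-processor")
    print(asr_model.get_l2_phoneme_sequence(waveform[0].numpy()))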