import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM, \
    Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
import pyctcdecode
import json
import re
from sys import platform
class PhonemeASRModel:
    def get_l2_phoneme_sequence(self, audio):
        """
        :param audio: audio sampled at a 16 kHz sampling rate (e.g. loaded with torchaudio)
        :type audio: array
        :return: predicted phonemes for the L2 speaker
        :rtype: array
        """
        pass
    def standardise_g2p_phoneme_sequence(self, phones):
        """
        To facilitate mispronunciation detection
        :param phones: native-speaker phones predicted by the G2P model
        :type phones: array
        :return: standardised native-speaker phoneme sequence that aligns with the phoneme classes used by the model
        :rtype: array
        """
        pass
    def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
        """
        To facilitate testing
        :param phones: native-speaker phones as annotated in L2-ARCTIC
        :type phones: array
        :return: standardised native-speaker phoneme sequence that aligns with the phoneme classes used by the model
        :rtype: array
        """
        pass
class Wav2Vec2PhonemeASRModel(PhonemeASRModel):
    """
    Uses greedy decoding
    """
    def __init__(self, model_path, processor_path):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = Wav2Vec2ForCTC.from_pretrained(model_path).to(self.device)
        self.processor = Wav2Vec2Processor.from_pretrained(processor_path)
    def get_l2_phoneme_sequence(self, audio):
        input_dict = self.processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
        logits = self.model(input_dict.input_values.to(self.device)).logits
        # greedy decoding: most likely token at each frame
        pred_ids = torch.argmax(logits, dim=-1)[0]
        # decode frame by frame and drop padding/blank frames, which decode to ""
        pred_phones = [phoneme for phoneme in self.processor.batch_decode(pred_ids) if phoneme != ""]
        return pred_phones
    def standardise_g2p_phoneme_sequence(self, phones):
        return phones
    def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
        # strip lexical stress digits, e.g. "AH0" -> "AH"
        return [re.sub(r'\d', "", phone_str) for phone_str in phones]
# TODO debug on linux because KenLM is not supported on Windows
class Wav2Vec2OptimisedPhonemeASRModel(PhonemeASRModel):
    """
    Uses beam search and an LM for decoding
    """
    def __init__(self, model_path, vocab_json_path, kenlm_model_path):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        with open(vocab_json_path) as f:
            vocab_dict = json.load(f)
        tokenizer = Wav2Vec2CTCTokenizer(vocab_json_path, unk_token="[UNK]", pad_token="[PAD]",
                                         word_delimiter_token="|")
        feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0,
                                                     do_normalize=True, return_attention_mask=False)
        labels = list(vocab_dict.keys())
        # beam search without a language model
        decoder = pyctcdecode.decoder.build_ctcdecoder(labels)
        if (platform == "linux" or platform == "linux2") and kenlm_model_path:
            # beam search + KenLM language model (only attempted on linux, see TODO above)
            decoder = pyctcdecode.decoder.build_ctcdecoder(labels, kenlm_model_path=kenlm_model_path)
        self.model = Wav2Vec2ForCTC.from_pretrained(model_path).to(self.device)
        self.processor = Wav2Vec2ProcessorWithLM(feature_extractor=feature_extractor, tokenizer=tokenizer,
                                                 decoder=decoder)
    def get_l2_phoneme_sequence(self, audio):
        input_dict = self.processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
        logits = self.model(input_dict.input_values.to(self.device)).logits.cpu().detach()
        # convert logits to per-frame probabilities before handing them to the beam-search decoder
        normalised_logits = torch.nn.functional.softmax(logits, dim=-1)
        normalised_logits = normalised_logits.numpy()[0]
        output = self.processor.decode(normalised_logits)
        pred_phones = output.text.split(" ")
        return pred_phones
    def standardise_g2p_phoneme_sequence(self, phones):
        return phones
    def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
        # strip lexical stress digits, e.g. "AH0" -> "AH"
        return [re.sub(r'\d', "", phone_str) for phone_str in phones]