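# Example script: load a fine-tuned Wav2Vec2 CTC model and the SpeechBrain
# SoundChoice G2P model, then run the MispronounciationDetector on a sample
# utterance (arctic_a0003) to report the phoneme error rate (PER) and
# phoneme-level errors.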
import os

import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from speechbrain.pretrained import GraphemeToPhoneme

from MispronounciationDetector import MispronounciationDetector
# Load sample data
audio_path = os.path.join(os.getcwd(), "data", "arctic_a0003.wav")
transcript_path = os.path.join(os.getcwd(), "data", "arctic_a0003.txt")
audio, org_sr = torchaudio.load(audio_path)
# Resample to the 16 kHz sampling rate expected by Wav2Vec2
audio = torchaudio.functional.resample(audio, orig_freq=org_sr, new_freq=16000)
# Flatten the (channels, samples) tensor into a 1-D waveform
audio = audio.view(audio.shape[1])
with open(transcript_path) as f:
    text = f.read()
print("Done loading sample data")
# Load processors and models
device = "cpu"
path = os.path.join(os.getcwd(), "model", "checkpoint-1200")
model = Wav2Vec2ForCTC.from_pretrained(path).to(device)
processor = Wav2Vec2Processor.from_pretrained(path)
g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
mispronounciation_detector = MispronounciationDetector(model, processor, g2p, device)
print("Done loading models and processors")
# Predict
raw_info = mispronounciation_detector.detect(audio, text)
# Stack the reference phonemes, the hypothesis phonemes and the per-phoneme
# errors line by line so they read as an alignment
aligned_phoneme_output_delimited_by_words = " ".join(raw_info['ref']) + "\n" + \
    " ".join(raw_info['hyp']) + "\n" + \
    " ".join(raw_info['phoneme_errors'])
print(f"PER: {raw_info['per']}\n")
print(f"Aligned phonemes with errors:\n{aligned_phoneme_output_delimited_by_words}\n")
print(f"Phoneme level errors:\n{raw_info['phoneme_output']}\n")