Spaces:
Runtime error
Runtime error
Update demo to use MultitaskASRModel
Browse files- wav2vecasr/demo.py +13 -9
wav2vecasr/demo.py
CHANGED
@@ -3,13 +3,17 @@ from speechbrain.pretrained import GraphemeToPhoneme
|
|
3 |
import datasets
|
4 |
import os
|
5 |
import torchaudio
|
6 |
-
from MispronounciationDetector import MispronounciationDetector
|
|
|
|
|
|
|
7 |
|
8 |
# Load sample data
|
9 |
-
audio_path, transcript_path = os.path.join(os.getcwd(), "data", "arctic_a0003.wav"), os.path.join(os.getcwd(),
|
10 |
audio, org_sr = torchaudio.load(audio_path)
|
11 |
audio = torchaudio.functional.resample(audio, orig_freq=org_sr, new_freq=16000)
|
12 |
audio = audio.view(audio.shape[1])
|
|
|
13 |
with open(transcript_path) as f:
|
14 |
text = f.read()
|
15 |
f.close()
|
@@ -17,16 +21,16 @@ print("Done loading sample data")
|
|
17 |
|
18 |
# Load processors and models
|
19 |
device = "cpu"
|
20 |
-
path = os.path.join(os.getcwd(), "model", "
|
21 |
-
|
22 |
-
|
23 |
g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
|
24 |
-
mispronounciation_detector = MispronounciationDetector(
|
25 |
print("Done loading models and processors")
|
26 |
|
27 |
# Predict
|
28 |
raw_info = mispronounciation_detector.detect(audio, text)
|
29 |
-
|
30 |
-
|
|
|
31 |
print(f"PER: {raw_info['per']}\n")
|
32 |
-
print(f"Phoneme level errors:\n{raw_info['phoneme_output']}\n")
|
|
|
"""Demo: run mispronunciation detection on one ARCTIC sample with the multitask ASR model.

Flat script: loads a wav + transcript from ./data, resamples to 16 kHz,
loads the MultitaskPhonemeASRModel checkpoint and the SoundChoice G2P,
then prints reference/hypothesis phonemes and the phoneme error rate (PER).
"""
from speechbrain.pretrained import GraphemeToPhoneme
import datasets  # NOTE(review): appears unused here — kept; may be needed elsewhere in the repo
import os
import torchaudio
from wav2vecasr.MispronounciationDetector import MispronounciationDetector
from wav2vecasr.PhonemeASRModel import Wav2Vec2PhonemeASRModel, Wav2Vec2OptimisedPhonemeASRModel, MultitaskPhonemeASRModel
import jiwer  # NOTE(review): unused in this demo — kept per repo imports
import re  # NOTE(review): unused in this demo — kept per repo imports

# All model/tensor work in this demo runs on CPU; defined once and reused
# below instead of repeating the "cpu" literal.
device = "cpu"

# Load sample data
audio_path, transcript_path = os.path.join(os.getcwd(), "data", "arctic_a0003.wav"), os.path.join(os.getcwd(), "data", "arctic_a0003.txt")
audio, org_sr = torchaudio.load(audio_path)
# The detector expects 16 kHz mono audio as a 1-D tensor.
audio = torchaudio.functional.resample(audio, orig_freq=org_sr, new_freq=16000)
audio = audio.view(audio.shape[1])  # (1, samples) -> (samples,)
audio = audio.to(device)
with open(transcript_path) as f:
    text = f.read()
    # Redundant f.close() removed: the `with` block closes the file.
print("Done loading sample data")

# Load processors and models
path = os.path.join(os.getcwd(), "model", "multitask_best_ctc.pt")
vocab_path = os.path.join(os.getcwd(), "model", "vocab")
asr_model = MultitaskPhonemeASRModel(path, vocab_path, device)
g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
mispronounciation_detector = MispronounciationDetector(asr_model, g2p, device)
print("Done loading models and processors")

# Predict
raw_info = mispronounciation_detector.detect(audio, text)
print(raw_info['ref'])
print(raw_info['hyp'])
print(raw_info['phoneme_errors'])
print(f"PER: {raw_info['per']}\n")