bel32123 committed on
Commit
b615647
1 Parent(s): 83aa7fc

Update demo to use MultitaskASRModel

Browse files
Files changed (1) hide show
  1. wav2vecasr/demo.py +13 -9
wav2vecasr/demo.py CHANGED
@@ -3,13 +3,17 @@ from speechbrain.pretrained import GraphemeToPhoneme
3
  import datasets
4
  import os
5
  import torchaudio
6
- from MispronounciationDetector import MispronounciationDetector
 
 
 
7
 
8
  # Load sample data
9
- audio_path, transcript_path = os.path.join(os.getcwd(), "data", "arctic_a0003.wav"), os.path.join(os.getcwd(), "data", "arctic_a0003.txt")
10
  audio, org_sr = torchaudio.load(audio_path)
11
  audio = torchaudio.functional.resample(audio, orig_freq=org_sr, new_freq=16000)
12
  audio = audio.view(audio.shape[1])
 
13
  with open(transcript_path) as f:
14
  text = f.read()
15
  f.close()
@@ -17,16 +21,16 @@ print("Done loading sample data")
17
 
18
  # Load processors and models
19
  device = "cpu"
20
- path = os.path.join(os.getcwd(), "model", "checkpoint-1200")
21
- model = Wav2Vec2ForCTC.from_pretrained(path).to(device)
22
- processor = Wav2Vec2Processor.from_pretrained(path)
23
  g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
24
- mispronounciation_detector = MispronounciationDetector(model, processor, g2p, "cpu")
25
  print("Done loading models and processors")
26
 
27
  # Predict
28
  raw_info = mispronounciation_detector.detect(audio, text)
29
- aligned_phoneme_output_delimited_by_words = " ".join(raw_info['ref']) + "\n" + " ".join(raw_info['hyp']) + "\n" +\
30
- " ".join(raw_info['phoneme_errors'])
 
31
  print(f"PER: {raw_info['per']}\n")
32
- print(f"Phoneme level errors:\n{raw_info['phoneme_output']}\n")
 
import datasets
import os
import torchaudio
from wav2vecasr.MispronounciationDetector import MispronounciationDetector
from wav2vecasr.PhonemeASRModel import Wav2Vec2PhonemeASRModel, Wav2Vec2OptimisedPhonemeASRModel, MultitaskPhonemeASRModel
import jiwer
import re

# Single source of truth for the compute device; used below for the audio
# tensor, the ASR model, and the detector (previously "cpu" was hard-coded
# in two places alongside this variable).
device = "cpu"

# Load sample data
audio_path = os.path.join(os.getcwd(), "data", "arctic_a0003.wav")
transcript_path = os.path.join(os.getcwd(), "data", "arctic_a0003.txt")
audio, org_sr = torchaudio.load(audio_path)
# Resample to 16 kHz — the input rate the wav2vec2-based ASR model expects.
audio = torchaudio.functional.resample(audio, orig_freq=org_sr, new_freq=16000)
# Flatten (channels, num_samples) -> (num_samples,) for the model.
audio = audio.view(audio.shape[1])
audio = audio.to(device)
# `with` closes the file on exit; the explicit f.close() was redundant.
with open(transcript_path) as f:
    text = f.read()

# Load processors and models
path = os.path.join(os.getcwd(), "model", "multitask_best_ctc.pt")
vocab_path = os.path.join(os.getcwd(), "model", "vocab")
asr_model = MultitaskPhonemeASRModel(path, vocab_path, device)
g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
mispronounciation_detector = MispronounciationDetector(asr_model, g2p, device)
print("Done loading models and processors")

# Predict: raw_info holds the reference/hypothesis phoneme sequences,
# the per-phoneme error alignment, and the phoneme error rate.
raw_info = mispronounciation_detector.detect(audio, text)
print(raw_info['ref'])
print(raw_info['hyp'])
print(raw_info['phoneme_errors'])
print(f"PER: {raw_info['per']}\n")