Yehor commited on
Commit
8133fdc
·
verified ·
1 Parent(s): 1f30811

Update run_demo.py

Browse files
Files changed (1) hide show
  1. run_demo.py +5 -6
run_demo.py CHANGED
@@ -12,7 +12,7 @@ device = "cuda:0" # cuda:0, or cpu
12
  torch_dtype = torch.float16
13
  sampling_rate = 16_000
14
 
15
- model_name = "Yehor/mHuBERT-147-uk"
16
  testset_file = "examples.csv"
17
 
18
  # Load the test dataset
@@ -29,23 +29,22 @@ asr_model = HubertForCTC.from_pretrained(
29
  processor = Wav2Vec2Processor.from_pretrained(model_name)
30
 
31
 
32
- # A util function to make batches
33
  def make_batches(iterable, n=1):
34
  lx = len(iterable)
35
  for ndx in range(0, lx, n):
36
  yield iterable[ndx : min(ndx + n, lx)]
37
 
38
 
39
- # Temporary variables
40
  predictions_all = []
41
  references_all = []
42
 
43
- # Inference in the batched mode
44
  for batch in make_batches(samples, batch_size):
45
  paths = [it["path"] for it in batch]
46
  references = [it["text"] for it in batch]
47
 
48
- # Extract audio
49
  audio_inputs = []
50
  for path in paths:
51
  audio_input, sampling_rate = torchaudio.load(path, backend="sox")
@@ -53,7 +52,7 @@ for batch in make_batches(samples, batch_size):
53
 
54
  audio_inputs.append(audio_input)
55
 
56
- # Transcribe the audio
57
  inputs = processor(audio_inputs, sampling_rate=16_000, padding=True).input_values
58
 
59
  features = torch.tensor(np.array(inputs), dtype=torch_dtype).to(device)
 
12
  torch_dtype = torch.float16
13
  sampling_rate = 16_000
14
 
15
+ model_name = "Yehor/hubert-uk"
16
  testset_file = "examples.csv"
17
 
18
  # Load the test dataset
 
29
  processor = Wav2Vec2Processor.from_pretrained(model_name)
30
 
31
 
32
+ # A func to make batches
33
  def make_batches(iterable, n=1):
34
  lx = len(iterable)
35
  for ndx in range(0, lx, n):
36
  yield iterable[ndx : min(ndx + n, lx)]
37
 
38
 
 
39
  predictions_all = []
40
  references_all = []
41
 
42
+ # Batched inference
43
  for batch in make_batches(samples, batch_size):
44
  paths = [it["path"] for it in batch]
45
  references = [it["text"] for it in batch]
46
 
47
+ # Extract audio features
48
  audio_inputs = []
49
  for path in paths:
50
  audio_input, sampling_rate = torchaudio.load(path, backend="sox")
 
52
 
53
  audio_inputs.append(audio_input)
54
 
55
+ # Transcribe
56
  inputs = processor(audio_inputs, sampling_rate=16_000, padding=True).input_values
57
 
58
  features = torch.tensor(np.array(inputs), dtype=torch_dtype).to(device)