Yehor committed on
Commit 5f72cc6 · verified · 1 Parent(s): 9081c0b

Create run_demo.py

Files changed (1)
  1. run_demo.py +93 -0
run_demo.py ADDED
@@ -0,0 +1,93 @@
import csv

import torch
import torchaudio
import numpy as np
import evaluate

from transformers import HubertForCTC, Wav2Vec2Processor

batch_size = 8
device = "cuda:0"  # or "cpu"
torch_dtype = torch.float16
sampling_rate = 16_000

model_name = "/home/yehor/ext-ml-disk/asr/hubert-training/models/final-85500"
testset_file = "/home/yehor/ext-ml-disk/asr/w2v2-bert-training/eval/rows_no_defis.csv"

# Load the test dataset
with open(testset_file) as f:
    samples = list(csv.DictReader(f))

# Load the model
asr_model = HubertForCTC.from_pretrained(
    model_name,
    device_map=device,
    torch_dtype=torch_dtype,
    # attn_implementation="flash_attention_2",
)
processor = Wav2Vec2Processor.from_pretrained(model_name)


# A utility function to split an iterable into batches
def make_batches(iterable, n=1):
    lx = len(iterable)
    for ndx in range(0, lx, n):
        yield iterable[ndx : min(ndx + n, lx)]


# Accumulators for the final metrics
predictions_all = []
references_all = []

# Run inference in batches
for batch in make_batches(samples, batch_size):
    paths = [it["path"] for it in batch]
    references = [it["text"] for it in batch]

    # Load audio (assumed to already be at 16 kHz)
    audio_inputs = []
    for path in paths:
        audio_input, _ = torchaudio.load(path, backend="sox")
        audio_input = audio_input.squeeze(0).numpy()

        audio_inputs.append(audio_input)

    # Transcribe the audio
    inputs = processor(audio_inputs, sampling_rate=sampling_rate, padding=True).input_values

    features = torch.tensor(np.array(inputs), dtype=torch_dtype).to(device)

    with torch.inference_mode():
        logits = asr_model(features).logits

    predicted_ids = torch.argmax(logits, dim=-1)
    predictions = processor.batch_decode(predicted_ids)

    # Log outputs
    print("---")
    print("Predictions:")
    print(predictions)
    print("References:")
    print(references)
    print("---")

    # Accumulate predictions and references
    predictions_all.extend(predictions)
    references_all.extend(references)

# Load evaluators
wer = evaluate.load("wer")
cer = evaluate.load("cer")

# Evaluate
wer_value = round(
    wer.compute(predictions=predictions_all, references=references_all), 4
)
cer_value = round(
    cer.compute(predictions=predictions_all, references=references_all), 4
)

# Print results
print("Final:")
print(f"WER: {wer_value} | CER: {cer_value}")
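For reference, the script reads each CSV row's "path" and "text" fields via DictReader, so the test file needs those two columns. Below is a minimal sketch that writes a compatible file; the file name, audio paths, and transcripts are hypothetical placeholders, not rows from the real eval set.

# Minimal sketch of the CSV layout run_demo.py expects, inferred from the
# it["path"] / it["text"] lookups above. All values below are placeholders.
import csv

rows = [
    {"path": "/data/eval/sample_0001.wav", "text": "example transcription one"},
    {"path": "/data/eval/sample_0002.wav", "text": "example transcription two"},
]

# Write a header row ("path,text") followed by one row per utterance
with open("rows.csv", "w", encoding="utf-8", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["path", "text"])
    writer.writeheader()
    writer.writerows(rows)

With model_name and testset_file pointed at local paths, the demo then runs as a plain script: python run_demo.py.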