Update run_demo.py
run_demo.py (+5 -6)
@@ -12,7 +12,7 @@ device = "cuda:0" # cuda:0, or cpu
 torch_dtype = torch.float16
 sampling_rate = 16_000

-model_name = "Yehor/
+model_name = "Yehor/hubert-uk"
 testset_file = "examples.csv"

 # Load the test dataset
@@ -29,23 +29,22 @@ asr_model = HubertForCTC.from_pretrained(
 processor = Wav2Vec2Processor.from_pretrained(model_name)


-# A
+# A func to make batches
 def make_batches(iterable, n=1):
     lx = len(iterable)
     for ndx in range(0, lx, n):
         yield iterable[ndx : min(ndx + n, lx)]


-# Temporary variables
 predictions_all = []
 references_all = []

-#
+# Batched inference
 for batch in make_batches(samples, batch_size):
     paths = [it["path"] for it in batch]
     references = [it["text"] for it in batch]

-    # Extract audio
+    # Extract audio features
     audio_inputs = []
     for path in paths:
         audio_input, sampling_rate = torchaudio.load(path, backend="sox")
@@ -53,7 +52,7 @@ for batch in make_batches(samples, batch_size):

         audio_inputs.append(audio_input)

-    # Transcribe
+    # Transcribe
     inputs = processor(audio_inputs, sampling_rate=16_000, padding=True).input_values

     features = torch.tensor(np.array(inputs), dtype=torch_dtype).to(device)
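
For context, the last hunk ends right after the feature tensor is built. Below is a minimal sketch of how the batch loop presumably continues: a forward pass through the CTC model followed by greedy decoding. The names asr_model, features, processor, references, predictions_all, and references_all come from the diff above; the decoding lines themselves are an assumption and are not part of this commit.

    # Hypothetical continuation of the loop body (not in this commit):
    # forward pass through HubertForCTC, then greedy CTC decoding.
    with torch.inference_mode():
        logits = asr_model(features).logits  # (batch, frames, vocab_size)

    predicted_ids = torch.argmax(logits, dim=-1)         # best token per frame
    predictions = processor.batch_decode(predicted_ids)  # collapses repeats and blanks

    predictions_all.extend(predictions)
    references_all.extend(references)

Wav2Vec2Processor.batch_decode and torch.inference_mode are standard transformers/PyTorch calls; greedy argmax decoding is the usual way to read out a CTC head when no language model is attached.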
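As a usage note, make_batches yields consecutive slices of at most n items, so the last batch can be shorter when len(samples) is not a multiple of batch_size:

    >>> list(make_batches([1, 2, 3, 4, 5], n=2))
    [[1, 2], [3, 4], [5]]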