Yehor committed on
Commit 96269f9
1 Parent(s): c28b45a

Update README.md

Files changed (1)
  1. README.md +106 -1
README.md CHANGED
@@ -23,4 +23,109 @@ model-index:
  value: 0.2035
  ---

- # `HuBERT` for Ukrainian
+ # HuBERT for Ukrainian
+
+
+ ## Community
+
+ - Discord: https://discord.gg/yVAjkBgmt4
+ - Speech Recognition: https://t.me/speech_recognition_uk
+ - Speech Synthesis: https://t.me/speech_synthesis_uk
+
+ ## Usage
+
+ ```python
+ import csv
+
+ import torch
+ import torchaudio
+ import numpy as np
+ import evaluate
+
+ from transformers import HubertForCTC, Wav2Vec2Processor
+
+ batch_size = 8
+ device = "cuda:0" # or cpu
+ torch_dtype = torch.float16
+ sampling_rate = 16_000
+
+ model_name = "/home/yehor/ext-ml-disk/asr/hubert-training/models/final-85500"
+ testset_file = "/home/yehor/ext-ml-disk/asr/w2v2-bert-training/eval/rows_no_defis.csv"
+
+ # Load the test dataset
+ with open(testset_file) as f:
+     samples = list(csv.DictReader(f))
+
+ # Load the model
+ asr_model = HubertForCTC.from_pretrained(
+     model_name,
+     device_map=device,
+     torch_dtype=torch_dtype,
+     # attn_implementation="flash_attention_2",
+ )
+ processor = Wav2Vec2Processor.from_pretrained(model_name)
+
+
+ # A util function to make batches
+ def make_batches(iterable, n=1):
+     lx = len(iterable)
+     for ndx in range(0, lx, n):
+         yield iterable[ndx : min(ndx + n, lx)]
+
+
+ # Temporary variables
+ predictions_all = []
+ references_all = []
+
+ # Inference in the batched mode
+ for batch in make_batches(samples, batch_size):
+     paths = [it["path"] for it in batch]
+     references = [it["text"] for it in batch]
+
+     # Extract audio
+     audio_inputs = []
+     for path in paths:
+         audio_input, sampling_rate = torchaudio.load(path, backend="sox")
+         audio_input = audio_input.squeeze(0).numpy()
+
+         audio_inputs.append(audio_input)
+
+     # Transcribe the audio
+     inputs = processor(audio_inputs, sampling_rate=16_000, padding=True).input_values
+
+     features = torch.tensor(np.array(inputs), dtype=torch_dtype).to(device)
+
+     with torch.inference_mode():
+         logits = asr_model(features).logits
+
+     predicted_ids = torch.argmax(logits, dim=-1)
+     predictions = processor.batch_decode(predicted_ids)
+
+     # Log outputs
+     print("---")
+     print("Predictions:")
+     print(predictions)
+     print("References:")
+     print(references)
+     print("---")
+
+     # Add predictions and references
+     predictions_all.extend(predictions)
+     references_all.extend(references)
+
+ # Load evaluators
+ wer = evaluate.load("wer")
+ cer = evaluate.load("cer")
+
+ # Evaluate
+ wer_value = round(
+     wer.compute(predictions=predictions_all, references=references_all), 4
+ )
+ cer_value = round(
+     cer.compute(predictions=predictions_all, references=references_all), 4
+ )
+
+ # Print results
+ print("Final:")
+ print(f"WER: {wer_value} | CER: {cer_value}")
+ ```
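
The usage example added in this commit runs a full batched WER/CER evaluation over a CSV test set. As a quick sanity check of the same classes (`HubertForCTC` and `Wav2Vec2Processor`), a minimal single-file transcription sketch could look like the following; it is not part of the README itself, the checkpoint path and `sample.wav` are placeholders, and the resampling step is an assumption for audio that is not already 16 kHz:

```python
import torch
import torchaudio

from transformers import HubertForCTC, Wav2Vec2Processor

# Placeholders: point these at the checkpoint and audio file you actually have.
model_name = "/path/to/hubert-ukrainian-checkpoint"
audio_path = "sample.wav"

device = "cuda:0" if torch.cuda.is_available() else "cpu"

asr_model = HubertForCTC.from_pretrained(model_name).to(device)
processor = Wav2Vec2Processor.from_pretrained(model_name)

# Load one utterance and resample it to the 16 kHz rate the model expects
waveform, sample_rate = torchaudio.load(audio_path)
if sample_rate != 16_000:
    waveform = torchaudio.functional.resample(waveform, sample_rate, 16_000)

inputs = processor(
    waveform.squeeze(0).numpy(), sampling_rate=16_000, return_tensors="pt"
).input_values.to(device)

with torch.inference_mode():
    logits = asr_model(inputs).logits

# Greedy CTC decoding, as in the evaluation script above
predicted_ids = torch.argmax(logits, dim=-1)
print(processor.batch_decode(predicted_ids)[0])
```

For half-precision GPU inference, reuse the `torch_dtype` handling from the evaluation script above and cast the input features to the same dtype as the model.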