Yehor committed
Commit 9081c0b · verified · 1 Parent(s): 96269f9

Update README.md

Files changed (1):
  1. README.md +11 -90
README.md CHANGED
@@ -32,100 +32,21 @@ model-index:
 - Speech Recognition: https://t.me/speech_recognition_uk
 - Speech Synthesis: https://t.me/speech_synthesis_uk
 
-## Usage
-
-```python
-import csv
-
-import torch
-import torchaudio
-import numpy as np
-import evaluate
-
-from transformers import HubertForCTC, Wav2Vec2Processor
-
-batch_size = 8
-device = "cuda:0" # or cpu
-torch_dtype = torch.float16
-sampling_rate = 16_000
-
-model_name = "/home/yehor/ext-ml-disk/asr/hubert-training/models/final-85500"
-testset_file = "/home/yehor/ext-ml-disk/asr/w2v2-bert-training/eval/rows_no_defis.csv"
-
-# Load the test dataset
-with open(testset_file) as f:
-    samples = list(csv.DictReader(f))
-
-# Load the model
-asr_model = HubertForCTC.from_pretrained(
-    model_name,
-    device_map=device,
-    torch_dtype=torch_dtype,
-    # attn_implementation="flash_attention_2",
-)
-processor = Wav2Vec2Processor.from_pretrained(model_name)
-
-
-# A util function to make batches
-def make_batches(iterable, n=1):
-    lx = len(iterable)
-    for ndx in range(0, lx, n):
-        yield iterable[ndx : min(ndx + n, lx)]
-
-
-# Temporary variables
-predictions_all = []
-references_all = []
-
-# Inference in the batched mode
-for batch in make_batches(samples, batch_size):
-    paths = [it["path"] for it in batch]
-    references = [it["text"] for it in batch]
-
-    # Extract audio
-    audio_inputs = []
-    for path in paths:
-        audio_input, sampling_rate = torchaudio.load(path, backend="sox")
-        audio_input = audio_input.squeeze(0).numpy()
-
-        audio_inputs.append(audio_input)
-
-    # Transcribe the audio
-    inputs = processor(audio_inputs, sampling_rate=16_000, padding=True).input_values
-
-    features = torch.tensor(np.array(inputs), dtype=torch_dtype).to(device)
-
-    with torch.inference_mode():
-        logits = asr_model(features).logits
-
-    predicted_ids = torch.argmax(logits, dim=-1)
-    predictions = processor.batch_decode(predicted_ids)
-
-    # Log outputs
-    print("---")
-    print("Predictions:")
-    print(predictions)
-    print("References:")
-    print(references)
-    print("---")
-
-    # Add predictions and references
-    predictions_all.extend(predictions)
-    references_all.extend(references)
-
-# Load evaluators
-wer = evaluate.load("wer")
-cer = evaluate.load("cer")
-
-# Evaluate
-wer_value = round(
-    wer.compute(predictions=predictions_all, references=references_all), 4
-)
-cer_value = round(
-    cer.compute(predictions=predictions_all, references=references_all), 4
-)
-
-# Print results
-print("Final:")
-print(f"WER: {wer_value} | CER: {cer_value}")
+## Install
+
+```text
+uv venv --python 3.12
+
+source .venv/bin/activate
+
+uv pip install -r requirements.txt
+
+# in development mode
+uv pip install -r requirements-dev.txt
+```
+
+## Usage
+
+```text
+python run_demo.py
 ```
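The removed `## Usage` block doubled as a batched WER/CER evaluation script; after this commit the README points to `run_demo.py` instead. For orientation, here is a minimal single-file sketch of the same `HubertForCTC` inference path, assuming a 16 kHz CTC checkpoint; `MODEL_NAME` and `sample.wav` are placeholders, not values from the repository, and the actual `run_demo.py` may differ:

```python
import torch
import torchaudio

from transformers import HubertForCTC, Wav2Vec2Processor

# Placeholders: substitute the real checkpoint id/path and an actual audio file
MODEL_NAME = "path/to/hubert-checkpoint"
AUDIO_FILE = "sample.wav"

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load the model and its processor (feature extractor + CTC tokenizer)
model = HubertForCTC.from_pretrained(MODEL_NAME).to(device)
processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)

# Load the audio and resample to the 16 kHz rate the model expects
waveform, sr = torchaudio.load(AUDIO_FILE)
if sr != 16_000:
    waveform = torchaudio.functional.resample(waveform, sr, 16_000)

# Turn the mono waveform into model input values
inputs = processor(
    waveform.squeeze(0).numpy(), sampling_rate=16_000, return_tensors="pt"
).input_values.to(device)

# Greedy CTC decoding: argmax over the vocabulary, then collapse blanks/repeats
with torch.inference_mode():
    logits = model(inputs).logits

predicted_ids = torch.argmax(logits, dim=-1)
print(processor.batch_decode(predicted_ids)[0])
```

The dropped evaluation script layers float16 weights, batching, and `evaluate`-based WER/CER scoring on top of this; greedy argmax decoding of the CTC logits is the common core of both.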