Update README.md
Updated the example script due to backward-incompatible changes in newer library versions.
README.md CHANGED

````diff
@@ -26,11 +26,12 @@ Initial evaluation on partially noisy data showed the model to achieve a word er
 
 ```python
 from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
-from datasets import Audio
 import soundfile as sf
 import torch
 import os
 
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
 # load model and tokenizer
 processor = Wav2Vec2Processor.from_pretrained(
     "classla/wav2vec2-xls-r-parlaspeech-hr")
@@ -38,28 +39,23 @@ model = Wav2Vec2ForCTC.from_pretrained("classla/wav2vec2-xls-r-parlaspeech-hr")
 
 
 # download the example wav files:
-os.system("
+os.system("wget https://huggingface.co/classla/wav2vec2-xls-r-parlaspeech-hr/raw/main/00020570a.flac.wav")
 
-# read the wav file
-
+# read the wav file
+speech, sample_rate = sf.read("00020570a.flac.wav")
+input_values = processor(speech, sampling_rate=sample_rate, return_tensors="pt").input_values.to(device)
 
 # remove the raw wav file
 os.system("rm 00020570a.flac.wav")
 
-# tokenize
-input_values = processor(
-    audio["array"], return_tensors="pt", padding=True,
-    sampling_rate=16000).input_values
-
 # retrieve logits
-logits = model(input_values).logits
+logits = model.to(device)(input_values).logits
 
 # take argmax and decode
 predicted_ids = torch.argmax(logits, dim=-1)
-transcription = processor.
-
+transcription = processor.decode(predicted_ids[0]).lower()
 
-# transcription:
+# transcription: 'veliki broj poslovnih subjekata posluje sa minusom velik dio'
 ```
 
 ## Training hyperparameters
````
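For reference, here is the updated example from the two hunks above assembled into a single runnable script. The `torch.no_grad()` context and the final `print` call are illustrative additions for this sketch, not part of the committed model card; everything else is taken directly from the diff.

```python
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import soundfile as sf
import torch
import os

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# load model and tokenizer
processor = Wav2Vec2Processor.from_pretrained(
    "classla/wav2vec2-xls-r-parlaspeech-hr")
model = Wav2Vec2ForCTC.from_pretrained("classla/wav2vec2-xls-r-parlaspeech-hr")

# download the example wav file
os.system("wget https://huggingface.co/classla/wav2vec2-xls-r-parlaspeech-hr/raw/main/00020570a.flac.wav")

# read the wav file and prepare the model input on the chosen device
speech, sample_rate = sf.read("00020570a.flac.wav")
input_values = processor(speech, sampling_rate=sample_rate, return_tensors="pt").input_values.to(device)

# remove the raw wav file
os.system("rm 00020570a.flac.wav")

# retrieve logits (no_grad is an addition here to skip gradient tracking during inference)
with torch.no_grad():
    logits = model.to(device)(input_values).logits

# take argmax and decode
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.decode(predicted_ids[0]).lower()
print(transcription)
# expected output per the model card:
# 'veliki broj poslovnih subjekata posluje sa minusom velik dio'
```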