Add better performing model
Browse files- README.md +19 -19
- config.json +2 -2
- pytorch_model.bin +1 -1
- vocab.json +1 -1
README.md
CHANGED
@@ -23,12 +23,12 @@ model-index:
|
|
23 |
metrics:
|
24 |
- name: Test WER
|
25 |
type: wer
|
26 |
-
value:
|
27 |
---
|
28 |
|
29 |
# Wav2Vec2-Large-XLSR-53-Kazakh
|
30 |
|
31 |
-
Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53)
|
32 |
|
33 |
When using this model, make sure that your speech input is sampled at 16kHz.
|
34 |
|
@@ -53,15 +53,15 @@ model = Wav2Vec2ForCTC.from_pretrained("wav2vec2-large-xlsr-kazakh")
|
|
53 |
# Preprocessing the datasets.
|
54 |
# We need to read the audio files as arrays
|
55 |
def speech_file_to_array_fn(batch):
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
|
60 |
test_dataset = test_dataset.map(speech_file_to_array_fn)
|
61 |
inputs = processor(test_dataset["speech"][:2], sampling_rate=16_000, return_tensors="pt", padding=True)
|
62 |
|
63 |
with torch.no_grad():
|
64 |
-
|
65 |
|
66 |
predicted_ids = torch.argmax(logits, dim=-1)
|
67 |
|
@@ -72,7 +72,7 @@ print("Reference:", test_dataset["sentence"][:2])
|
|
72 |
|
73 |
## Evaluation
|
74 |
|
75 |
-
The model can be evaluated as follows on the test
|
76 |
|
77 |
```python
|
78 |
import torch
|
@@ -94,31 +94,31 @@ model.to("cuda")
|
|
94 |
# Preprocessing the datasets.
|
95 |
# We need to read the audio files as arrays
|
96 |
def speech_file_to_array_fn(batch):
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
|
102 |
test_dataset = test_dataset.map(speech_file_to_array_fn)
|
103 |
|
104 |
def evaluate(batch):
|
105 |
-
|
106 |
|
107 |
-
|
108 |
-
|
109 |
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
|
114 |
result = test_dataset.map(evaluate, batched=True, batch_size=8)
|
115 |
|
116 |
print("WER: {:2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])))
|
117 |
```
|
118 |
|
119 |
-
**Test Result**:
|
120 |
|
121 |
|
122 |
## Training
|
123 |
|
124 |
-
The Kazakh Speech Corpus v1.1 `train` dataset was used for training
|
|
|
23 |
metrics:
|
24 |
- name: Test WER
|
25 |
type: wer
|
26 |
+
value: 19.65
|
27 |
---
|
28 |
|
29 |
# Wav2Vec2-Large-XLSR-53-Kazakh
|
30 |
|
31 |
+
Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) for Kazakh ASR using the [Kazakh Speech Corpus v1.1](https://issai.nu.edu.kz/kz-speech-corpus/?version=1.1)
|
32 |
|
33 |
When using this model, make sure that your speech input is sampled at 16kHz.
|
34 |
|
|
|
53 |
# Preprocessing the datasets.
|
54 |
# We need to read the audio files as arrays
|
55 |
def speech_file_to_array_fn(batch):
|
56 |
+
speech_array, sampling_rate = torchaudio.load(batch["path"])
|
57 |
+
batch["speech"] = torchaudio.transforms.Resample(sampling_rate, 16_000)(speech_array).squeeze().numpy()
|
58 |
+
return batch
|
59 |
|
60 |
test_dataset = test_dataset.map(speech_file_to_array_fn)
|
61 |
inputs = processor(test_dataset["speech"][:2], sampling_rate=16_000, return_tensors="pt", padding=True)
|
62 |
|
63 |
with torch.no_grad():
|
64 |
+
logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
|
65 |
|
66 |
predicted_ids = torch.argmax(logits, dim=-1)
|
67 |
|
|
|
72 |
|
73 |
## Evaluation
|
74 |
|
75 |
+
The model can be evaluated as follows on the test set of [Kazakh Speech Corpus v1.1](https://issai.nu.edu.kz/kz-speech-corpus/?version=1.1). To evaluate, download the [archive](https://www.openslr.org/resources/102/ISSAI_KSC_335RS_v1.1_flac.tar.gz), untar and pass the path to data to `get_test_dataset` as below:
|
76 |
|
77 |
```python
|
78 |
import torch
|
|
|
94 |
# Preprocessing the datasets.
|
95 |
# We need to read the audio files as arrays
|
96 |
def speech_file_to_array_fn(batch):
|
97 |
+
batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower()
|
98 |
+
speech_array, sampling_rate = torchaudio.load(batch["path"])
|
99 |
+
batch["speech"] = torchaudio.transforms.Resample(sampling_rate, 16_000)(speech_array).squeeze().numpy()
|
100 |
+
return batch
|
101 |
|
102 |
test_dataset = test_dataset.map(speech_file_to_array_fn)
|
103 |
|
104 |
def evaluate(batch):
|
105 |
+
inputs = processor(batch["text"], sampling_rate=16_000, return_tensors="pt", padding=True)
|
106 |
|
107 |
+
with torch.no_grad():
|
108 |
+
logits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits
|
109 |
|
110 |
+
pred_ids = torch.argmax(logits, dim=-1)
|
111 |
+
batch["pred_strings"] = processor.batch_decode(pred_ids)
|
112 |
+
return batch
|
113 |
|
114 |
result = test_dataset.map(evaluate, batched=True, batch_size=8)
|
115 |
|
116 |
print("WER: {:2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])))
|
117 |
```
|
118 |
|
119 |
+
**Test Result**: 19.65%
|
120 |
|
121 |
|
122 |
## Training
|
123 |
|
124 |
+
The Kazakh Speech Corpus v1.1 `train` dataset was used for training.
|
config.json
CHANGED
@@ -42,7 +42,7 @@
|
|
42 |
"feat_extract_activation": "gelu",
|
43 |
"feat_extract_dropout": 0.0,
|
44 |
"feat_extract_norm": "layer",
|
45 |
-
"feat_proj_dropout": 0
|
46 |
"final_dropout": 0.0,
|
47 |
"gradient_checkpointing": true,
|
48 |
"hidden_act": "gelu",
|
@@ -62,7 +62,7 @@
|
|
62 |
"mask_time_length": 10,
|
63 |
"mask_time_min_space": 1,
|
64 |
"mask_time_other": 0.0,
|
65 |
-
"mask_time_prob": 0,
|
66 |
"mask_time_selection": "static",
|
67 |
"model_type": "wav2vec2",
|
68 |
"num_attention_heads": 16,
|
|
|
42 |
"feat_extract_activation": "gelu",
|
43 |
"feat_extract_dropout": 0.0,
|
44 |
"feat_extract_norm": "layer",
|
45 |
+
"feat_proj_dropout": 0,
|
46 |
"final_dropout": 0.0,
|
47 |
"gradient_checkpointing": true,
|
48 |
"hidden_act": "gelu",
|
|
|
62 |
"mask_time_length": 10,
|
63 |
"mask_time_min_space": 1,
|
64 |
"mask_time_other": 0.0,
|
65 |
+
"mask_time_prob": 0.05,
|
66 |
"mask_time_selection": "static",
|
67 |
"model_type": "wav2vec2",
|
68 |
"num_attention_heads": 16,
|
pytorch_model.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 1262118359
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b1e83db5ad4984dbe05208534fd2003de46257634517f054984a09c4d61a1ace
|
3 |
size 1262118359
|
vocab.json
CHANGED
@@ -1 +1 @@
|
|
1 |
-
{"
|
|
|
1 |
+
{"а": 1, "б": 2, "в": 3, "г": 4, "д": 5, "е": 6, "ж": 7, "з": 8, "и": 9, "й": 10, "к": 11, "л": 12, "м": 13, "н": 14, "о": 15, "п": 16, "р": 17, "с": 18, "т": 19, "у": 20, "ф": 21, "х": 22, "ц": 23, "ч": 24, "ш": 25, "щ": 26, "ъ": 27, "ы": 28, "ь": 29, "э": 30, "ю": 31, "я": 32, "ё": 33, "і": 34, "ғ": 35, "қ": 36, "ң": 37, "ү": 38, "ұ": 39, "һ": 40, "ә": 41, "ө": 42, "|": 0, "[UNK]": 43, "[PAD]": 44}
|