elsayedissa commited on
Commit
94065aa
·
1 Parent(s): 75154fc

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +60 -2
README.md CHANGED
@@ -58,7 +58,7 @@ The following hyperparameters were used during training:
58
  | 0.1361 | 0.56 | 4000 | 0.2372 | 0.1330 |
59
  | 0.1211 | 0.69 | 5000 | 0.2297 | 0.1282 |
60
 
61
- ### Transcription
62
 
63
  ```python
64
  from datasets import load_dataset, Audio
@@ -79,7 +79,7 @@ commonvoice_eval = commonvoice_eval.cast_column("audio", Audio(sampling_rate=160
79
  sample = next(iter(commonvoice_eval))["audio"]
80
 
81
  # features and generate token ids
82
- input_features = processor(sample["array"], sampling_rate=input_speech["sampling_rate"], return_tensors="pt").input_features
83
  predicted_ids = model.generate(input_features.to(device), forced_decoder_ids=forced_decoder_ids)
84
 
85
  # decode
@@ -88,6 +88,64 @@ transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
88
 
89
  print(transcription)
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  ```
92
 
93
  ### Framework versions
 
58
  | 0.1361 | 0.56 | 4000 | 0.2372 | 0.1330 |
59
  | 0.1211 | 0.69 | 5000 | 0.2297 | 0.1282 |
60
 
61
+ ### Transcription:
62
 
63
  ```python
64
  from datasets import load_dataset, Audio
 
79
  sample = next(iter(commonvoice_eval))["audio"]
80
 
81
  # features and generate token ids
82
+ input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features
83
  predicted_ids = model.generate(input_features.to(device), forced_decoder_ids=forced_decoder_ids)
84
 
85
  # decode
 
88
 
89
  print(transcription)
90
 
91
+ ```
92
+
93
+ ### Evaluation:
94
+
95
+ Evaluates this model on `mozilla-foundation/common_voice_11_0` test split.
96
+
97
+ ```python
98
+ from transformers.models.whisper.english_normalizer import BasicTextNormalizer
99
+ from datasets import load_dataset, Audio
100
+ import evaluate
101
+ import torch
102
+ import re
103
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
104
+
105
+ # device
106
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
107
+
108
+ # metric
109
+ wer_metric = evaluate.load("wer")
110
+
111
+ # model
112
+ processor = WhisperProcessor.from_pretrained("clu-ling/whisper-large-v2-spanish-5k-steps")
113
+ model = WhisperForConditionalGeneration.from_pretrained("clu-ling/whisper-large-v2-spanish-5k-steps")
114
+
115
+ # dataset
116
+ dataset = load_dataset("mozilla-foundation/common_voice_11_0", "es", split="test", )#cache_dir=args.cache_dir
117
+ dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
118
+
119
+ #for debuggings: it gets some examples
120
+ #dataset = dataset.shard(num_shards=10000, index=0)
121
+ #print(dataset)
122
+
123
+ def normalize(batch):
124
+ """Normalizes GOLD"""
125
+ batch["gold_text"] = whisper_norm(batch['sentence'])
126
+ return batch
127
+
128
+ def map_wer(batch):
129
+ model.to(args.device)
130
+ forced_decoder_ids = processor.get_decoder_prompt_ids(language = "es", task = "transcribe")
131
+ inputs = processor(batch["audio"]["array"], sampling_rate=batch["audio"]["sampling_rate"], return_tensors="pt").input_features
132
+ with torch.no_grad():
133
+ generated_ids = model.generate(inputs=inputs.to(device), forced_decoder_ids=forced_decoder_ids)
134
+ transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
135
+ batch["predicted_text"] = whisper_norm(transcription)
136
+ return batch
137
+
138
+ # process GOLD text
139
+ processed_dataset = dataset.map(normalize)
140
+ # get predictions
141
+ predicted_dataset = processed_dataset.map(map_wer)
142
+
143
+ # word error rate
144
+ wer = wer_metric.compute(references=predicted_dataset['gold_text'], predictions=predicted_dataset['predicted_text'])
145
+ wer = round(100 * wer, 2)
146
+ print("WER:", wer)
147
+
148
+
149
  ```
150
 
151
  ### Framework versions