ayameRushia committed
Commit 9ae1879
1 Parent(s): 21337d8

Delete eval.py

Files changed (1)
  1. eval.py +0 -131
eval.py DELETED
@@ -1,131 +0,0 @@
- import argparse
- import re
- from typing import Dict
-
- import torch
- from datasets import Audio, Dataset, load_dataset, load_metric
- from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM
-
- # punctuation and stray symbols to strip from reference transcripts
- chars_to_ignore_regex = '[é!,?.;:"“%‘”�\'’—–·-]'
-
- def log_results(result: Dataset, args: Dict[str, str]):
-     """DO NOT CHANGE. This function computes and logs the result metrics."""
-
-     log_outputs = args.log_outputs
-     dataset_id = "_".join(args.dataset.split("/") + [args.config, args.split])
-
-     # load metrics
-     wer = load_metric("wer")
-     cer = load_metric("cer")
-
-     # compute metrics
-     wer_result = wer.compute(references=result["sentence"], predictions=result["pred_strings"])
-     cer_result = cer.compute(references=result["sentence"], predictions=result["pred_strings"])
-
-     # print & log results
-     result_str = f"WER: {wer_result}\nCER: {cer_result}"
-     print(result_str)
-
-     with open(f"{dataset_id}_eval_results.txt", "w") as f:
-         f.write(result_str)
-
-     # log all results in a text file; possibly interesting for analysis
-     if log_outputs:  # `--log_outputs` is a store_true flag, so test truthiness, not `is not None`
-         pred_file = f"log_{dataset_id}_predictions.txt"
-         target_file = f"log_{dataset_id}_targets.txt"
-
-         with open(pred_file, "w") as p, open(target_file, "w") as t:
-
-             # mapping function to write predictions and targets line by line
-             def write_to_file(batch, i):
-                 p.write(f"{i}\n")
-                 p.write(batch["pred_strings"] + "\n")
-                 t.write(f"{i}\n")
-                 t.write(batch["sentence"] + "\n")
-
-             result.map(write_to_file, with_indices=True)
-
- def load_data(dataset_id, language, split="test"):
-     test_dataset = load_dataset(dataset_id, language, split=split, use_auth_token=True)
-     # resample audio on the fly to the 16 kHz expected by Wav2Vec2
-     test_dataset = test_dataset.cast_column("audio", Audio(sampling_rate=16_000))
-     return test_dataset
-
- def speech_file_to_array_fn(batch):
-     # normalize the reference transcript: strip punctuation and brackets,
-     # collapse runs of whitespace, and lowercase
-     sentence = re.sub(chars_to_ignore_regex, "", batch["sentence"])
-     for char in '"&\'()[]\\«»':
-         sentence = sentence.replace(char, "")
-     batch["sentence"] = re.sub(r"\s+", " ", sentence).lower().strip() + " "
-     batch["speech"] = batch["audio"]["array"]
-     return batch
-
- def main(args):
-     test_dataset = load_data(args.dataset, args.config, args.split)
-     test_dataset = test_dataset.map(speech_file_to_array_fn)
-     model_id = args.model_id
-
-     def evaluate_with_lm(batch):
-         inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
-         with torch.no_grad():
-             logits = model(**inputs.to("cuda")).logits
-         # beam-search decoding with the language model bundled in the processor
-         int_result = processor.batch_decode(logits.cpu().numpy())
-         batch["pred_strings"] = int_result.text
-         return batch
-
-     def evaluate(batch):
-         inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
-         with torch.no_grad():
-             logits = model(inputs.input_values.to("cuda")).logits
-         # greedy CTC decoding
-         pred_ids = torch.argmax(logits, dim=-1)
-         batch["pred_strings"] = processor.batch_decode(pred_ids, skip_special_tokens=True)
-         return batch
-
-     if args.lm:
-         processor = Wav2Vec2ProcessorWithLM.from_pretrained(model_id, use_auth_token=True)
-         model = Wav2Vec2ForCTC.from_pretrained(model_id, use_auth_token=True)
-         model.to("cuda")
-         result = test_dataset.map(evaluate_with_lm, batched=True, batch_size=4)
-     else:
-         processor = Wav2Vec2Processor.from_pretrained(model_id, use_auth_token=True)
-         model = Wav2Vec2ForCTC.from_pretrained(model_id, use_auth_token=True)
-         model.to("cuda")
-         result = test_dataset.map(evaluate, batched=True, batch_size=4)
-
-     log_results(result, args)
-
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser()
-
-     parser.add_argument(
-         "--model_id", type=str, required=True, help="Model identifier. Should be loadable with 🤗 Transformers."
-     )
-     parser.add_argument(
-         "--dataset", type=str, required=True, help="Dataset name to evaluate the `model_id` on. Should be loadable with 🤗 Datasets."
-     )
-     parser.add_argument(
-         "--config", type=str, required=True, help="Config of the dataset. *E.g.* `'en'` for Common Voice."
-     )
-     parser.add_argument(
-         "--split", type=str, required=True, help="Split of the dataset. *E.g.* `'test'`."
-     )
-     # NOTE: --chunk_length_s and --stride_length_s are accepted but not used by main()
-     parser.add_argument(
-         "--chunk_length_s", type=float, default=None, help="Chunk length in seconds. Defaults to None. For long audio files a good value is 5.0 seconds."
-     )
-     parser.add_argument(
-         "--stride_length_s", type=float, default=None, help="Stride of the audio chunks in seconds. Defaults to None. For long audio files a good value is 1.0 seconds."
-     )
-     parser.add_argument(
-         "--log_outputs", action="store_true", help="If set, write predictions and targets to log files for analysis."
-     )
-     parser.add_argument(
-         "--lm", action="store_true", help="Whether to use a language model for decoding during evaluation."
-     )
-     args = parser.parse_args()
-
-     main(args)
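
For reference, a typical invocation of a script like this would have looked as follows (a hypothetical sketch: the model identifier is a placeholder and the dataset/config values are only examples, not taken from this commit):

    python eval.py \
        --model_id <your-model-id> \
        --dataset mozilla-foundation/common_voice_8_0 \
        --config en \
        --split test \
        --log_outputs

Adding --lm would switch decoding to the language-model path, which requires the checkpoint to be loadable with Wav2Vec2ProcessorWithLM.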