Update eval script to compute WER for each sample individually. Update FLEURS predictions with a WER column.
Browse files
predictions/preds_google_fleurs_be_by_test_20221221-101048.tsv
DELETED
The diff for this file is too large to render.
See raw diff
|
|
predictions/preds_google_fleurs_be_by_test_20221221-101048.xlsx
ADDED
Binary file (242 kB). View file
|
|
src/run_eval_whisper_streaming.py
CHANGED
@@ -10,6 +10,7 @@ from transformers import pipeline
|
|
10 |
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
|
11 |
from datasets import load_dataset, Audio
|
12 |
import evaluate
|
|
|
13 |
|
14 |
from belarusian_text_normalizer import BelarusianTextNormalizer
|
15 |
|
@@ -33,6 +34,21 @@ wer_metric = evaluate.load("wer")
|
|
33 |
text_normalizer = BelarusianTextNormalizer()
|
34 |
|
35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
def is_target_text_in_range(ref):
|
37 |
if ref.strip() == "ignore time segment in scoring":
|
38 |
return False
|
@@ -106,15 +122,30 @@ def main(args):
|
|
106 |
logger.info(f'WER: {wer}')
|
107 |
|
108 |
if args.save_predictions is True:
|
109 |
-
preds_fp = f'preds_{args.dataset}_{args.config}_{args.split}_{now_str}.tsv'
|
110 |
preds_fp = clean_filename(preds_fp)
|
111 |
logger.info(f'saving predictions to: "{preds_fp}"')
|
|
|
112 |
preds_df = pd.DataFrame({
|
113 |
'audio_path': audio_paths,
|
114 |
'prediction_norm': predictions_norm, 'reference_norm': references_norm,
|
115 |
'prediction': predictions, 'reference': references,
|
116 |
})
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
else:
|
119 |
logger.info('save_predictions is False. will not save predictions to a file')
|
120 |
|
|
|
10 |
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
|
11 |
from datasets import load_dataset, Audio
|
12 |
import evaluate
|
13 |
+
import jiwer
|
14 |
|
15 |
from belarusian_text_normalizer import BelarusianTextNormalizer
|
16 |
|
|
|
34 |
text_normalizer = BelarusianTextNormalizer()
|
35 |
|
36 |
|
37 |
+
def pull_columns(df: pd.DataFrame, cols) -> pd.DataFrame:
    """Return a copy of ``df`` with ``cols`` moved to the front.

    The requested columns come first, in the order given; all remaining
    columns follow in their original relative order. No column is dropped.

    Args:
        df: source dataframe; it is not modified.
        cols: a single column name, or an iterable of column names.
            Repeated names are ignored after their first occurrence.

    Returns:
        A new dataframe (a copy) with the same columns as ``df``, reordered.

    Raises:
        KeyError: if any requested column is not present in ``df``.
        ValueError: if the reordered frame does not have the same number of
            columns as ``df`` (happens when ``df`` has duplicate labels).
    """
    if isinstance(cols, str):
        cols = [cols]
    # dict.fromkeys drops repeated names while keeping first-seen order,
    # so a repeated name cannot duplicate a column in the result.
    cols = list(dict.fromkeys(cols))

    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    # A list comprehension keeps the reported names in the caller's order.
    absent_cols = [c for c in cols if c not in df.columns]
    if absent_cols:
        raise KeyError(f'{absent_cols} columns are absent in df')

    pulled = set(cols)
    cols_rest = [c for c in df.columns if c not in pulled]
    new_df = df[cols + cols_rest].copy()

    # Sanity check: reordering must neither add nor lose columns.
    if new_df.shape[1] != df.shape[1]:
        raise ValueError(
            f'column count changed from {df.shape[1]} to {new_df.shape[1]}; '
            'df probably has duplicate column labels'
        )
    return new_df
|
50 |
+
|
51 |
+
|
52 |
def is_target_text_in_range(ref):
|
53 |
if ref.strip() == "ignore time segment in scoring":
|
54 |
return False
|
|
|
122 |
logger.info(f'WER: {wer}')
|
123 |
|
124 |
if args.save_predictions is True:
|
125 |
+
preds_fp = f'preds_{args.dataset}_{args.config}_{args.split}_{now_str}.xlsx'
|
126 |
preds_fp = clean_filename(preds_fp)
|
127 |
logger.info(f'saving predictions to: "{preds_fp}"')
|
128 |
+
|
129 |
preds_df = pd.DataFrame({
|
130 |
'audio_path': audio_paths,
|
131 |
'prediction_norm': predictions_norm, 'reference_norm': references_norm,
|
132 |
'prediction': predictions, 'reference': references,
|
133 |
})
|
134 |
+
|
135 |
+
logger.info('computing WER for each item individually')
|
136 |
+
preds_df['wer'] = preds_df.apply(
|
137 |
+
lambda row: 100 * jiwer.wer(
|
138 |
+
truth=row['reference_norm'], hypothesis=row['prediction_norm']),
|
139 |
+
axis=1
|
140 |
+
)
|
141 |
+
preds_df.sort_values('wer', ascending=False, inplace=True)
|
142 |
+
|
143 |
+
# use pull_columns instead of direct dataframe indexing
|
144 |
+
# so as not to drop any columns that may be added to the dataframe in the future.
|
145 |
+
cols_order = ['audio_path', 'wer', 'prediction_norm', 'reference_norm', 'prediction', 'reference']
|
146 |
+
preds_df = pull_columns(preds_df, cols=cols_order)
|
147 |
+
|
148 |
+
preds_df.to_excel(preds_fp, index=False)
|
149 |
else:
|
150 |
logger.info('save_predictions is False. will not save predictions to a file')
|
151 |
|