ales commited on
Commit
a05646f
1 Parent(s): ed29fed

upd eval script to compute WER for each sample individually. upd fleurs predictions with WER column

Browse files
predictions/preds_google_fleurs_be_by_test_20221221-101048.tsv DELETED
The diff for this file is too large to render. See raw diff
 
predictions/preds_google_fleurs_be_by_test_20221221-101048.xlsx ADDED
Binary file (242 kB). View file
 
src/run_eval_whisper_streaming.py CHANGED
@@ -10,6 +10,7 @@ from transformers import pipeline
10
  from transformers.models.whisper.english_normalizer import BasicTextNormalizer
11
  from datasets import load_dataset, Audio
12
  import evaluate
 
13
 
14
  from belarusian_text_normalizer import BelarusianTextNormalizer
15
 
@@ -33,6 +34,21 @@ wer_metric = evaluate.load("wer")
33
  text_normalizer = BelarusianTextNormalizer()
34
 
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def is_target_text_in_range(ref):
37
  if ref.strip() == "ignore time segment in scoring":
38
  return False
@@ -106,15 +122,30 @@ def main(args):
106
  logger.info(f'WER: {wer}')
107
 
108
  if args.save_predictions is True:
109
- preds_fp = f'preds_{args.dataset}_{args.config}_{args.split}_{now_str}.tsv'
110
  preds_fp = clean_filename(preds_fp)
111
  logger.info(f'saving predictions to: "{preds_fp}"')
 
112
  preds_df = pd.DataFrame({
113
  'audio_path': audio_paths,
114
  'prediction_norm': predictions_norm, 'reference_norm': references_norm,
115
  'prediction': predictions, 'reference': references,
116
  })
117
- preds_df.to_csv(preds_fp, sep='\t', index=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  else:
119
  logger.info('save_predictions is False. will not save predictions to a file')
120
 
 
10
  from transformers.models.whisper.english_normalizer import BasicTextNormalizer
11
  from datasets import load_dataset, Audio
12
  import evaluate
13
+ import jiwer
14
 
15
  from belarusian_text_normalizer import BelarusianTextNormalizer
16
 
 
34
  text_normalizer = BelarusianTextNormalizer()
35
 
36
 
37
+ def pull_columns(df: pd.DataFrame, cols) -> pd.DataFrame:
38
+ """ Pull columns to the beginning of the dataframe """
39
+ if isinstance(cols, str):
40
+ cols = [cols]
41
+ cols = list(cols)
42
+
43
+ absent_cols = list(set(cols).difference(df.columns))
44
+ assert len(absent_cols) == 0, f'{absent_cols} columns are absent in df'
45
+
46
+ cols_rest = [c for c in df.columns if c not in cols]
47
+ new_df = df[cols + cols_rest].copy()
48
+ assert new_df.shape[1] == df.shape[1]
49
+ return new_df
50
+
51
+
52
  def is_target_text_in_range(ref):
53
  if ref.strip() == "ignore time segment in scoring":
54
  return False
 
122
  logger.info(f'WER: {wer}')
123
 
124
  if args.save_predictions is True:
125
+ preds_fp = f'preds_{args.dataset}_{args.config}_{args.split}_{now_str}.xlsx'
126
  preds_fp = clean_filename(preds_fp)
127
  logger.info(f'saving predictions to: "{preds_fp}"')
128
+
129
  preds_df = pd.DataFrame({
130
  'audio_path': audio_paths,
131
  'prediction_norm': predictions_norm, 'reference_norm': references_norm,
132
  'prediction': predictions, 'reference': references,
133
  })
134
+
135
+ logger.info('computing WER for each item individually')
136
+ preds_df['wer'] = preds_df.apply(
137
+ lambda row: 100 * jiwer.wer(
138
+ truth=row['reference_norm'], hypothesis=row['prediction_norm']),
139
+ axis=1
140
+ )
141
+ preds_df.sort_values('wer', ascending=False, inplace=True)
142
+
143
+ # use pull_columns instead of direct dataframe indexing
144
+ # not to delete any columns that could be added to dataframe in future.
145
+ cols_order = ['audio_path', 'wer', 'prediction_norm', 'reference_norm', 'prediction', 'reference']
146
+ preds_df = pull_columns(preds_df, cols=cols_order)
147
+
148
+ preds_df.to_excel(preds_fp, index=False)
149
  else:
150
  logger.info('save_predictions is False. will not save predictions to a file')
151