huckiyang commited on
Commit
9f029d4
·
1 Parent(s): d7d6438

more LM baseline

Browse files
Files changed (1) hide show
  1. app.py +32 -26
app.py CHANGED
@@ -404,24 +404,31 @@ def get_wer_metrics(dataset):
404
  "N-best Correction": np.nan
405
  }
406
 
407
- # Create a transposed DataFrame with metrics as rows and sources as columns
408
- metrics = ["Count", "No LM Baseline", "N-best LM Ranking", "N-best Correction"]
409
- result_df = pd.DataFrame(index=metrics, columns=["Metric"] + all_sources + ["OVERALL"])
410
 
411
- # Add descriptive column
412
- result_df["Metric"] = [
413
- "Number of Examples",
414
- "Word Error Rate (No LM)",
415
- "Word Error Rate (N-best LM Ranking)",
416
- "Word Error Rate (N-best Correction)"
417
- ]
 
 
 
418
 
419
  for source in all_sources + ["OVERALL"]:
420
- for metric in metrics:
421
- result_df.loc[metric, source] = source_results[source][metric]
 
 
 
 
 
422
 
423
- # Set Metric as index for better display
424
- result_df = result_df.set_index("Metric")
425
 
426
  return result_df
427
 
@@ -438,20 +445,19 @@ def format_dataframe(df):
438
 
439
  # Find the rows containing WER values
440
  wer_row_indices = []
441
- for idx in df.index:
442
- if "WER" in idx or "Error Rate" in idx:
443
- wer_row_indices.append(idx)
444
 
445
- for wer_row_index in wer_row_indices:
446
- # Convert to object type first to avoid warnings
447
- df.loc[wer_row_index] = df.loc[wer_row_index].astype(object)
448
-
449
  for col in df.columns:
450
- value = df.loc[wer_row_index, col]
451
- if pd.notna(value):
452
- df.loc[wer_row_index, col] = f"{value:.4f}"
453
- else:
454
- df.loc[wer_row_index, col] = "N/A"
 
455
 
456
  return df
457
 
 
404
  "N-best Correction": np.nan
405
  }
406
 
407
+ # Create flat DataFrame with labels in the first column
408
+ rows = []
 
409
 
410
+ # First add row for number of examples
411
+ example_row = {"Metric": "Number of Examples"}
412
+ for source in all_sources + ["OVERALL"]:
413
+ example_row[source] = source_results[source]["Count"]
414
+ rows.append(example_row)
415
+
416
+ # Then add rows for each WER method
417
+ no_lm_row = {"Metric": "Word Error Rate (No LM)"}
418
+ lm_ranking_row = {"Metric": "Word Error Rate (N-best LM Ranking)"}
419
+ n_best_row = {"Metric": "Word Error Rate (N-best Correction)"}
420
 
421
  for source in all_sources + ["OVERALL"]:
422
+ no_lm_row[source] = source_results[source]["No LM Baseline"]
423
+ lm_ranking_row[source] = source_results[source]["N-best LM Ranking"]
424
+ n_best_row[source] = source_results[source]["N-best Correction"]
425
+
426
+ rows.append(no_lm_row)
427
+ rows.append(lm_ranking_row)
428
+ rows.append(n_best_row)
429
 
430
+ # Create DataFrame from rows
431
+ result_df = pd.DataFrame(rows)
432
 
433
  return result_df
434
 
 
445
 
446
  # Find the rows containing WER values
447
  wer_row_indices = []
448
+ for i, metric in enumerate(df["Metric"]):
449
+ if "WER" in metric or "Error Rate" in metric:
450
+ wer_row_indices.append(i)
451
 
452
+ # Format WER values
453
+ for idx in wer_row_indices:
 
 
454
  for col in df.columns:
455
+ if col != "Metric": # Skip the metric column
456
+ value = df.loc[idx, col]
457
+ if pd.notna(value):
458
+ df.loc[idx, col] = f"{value:.4f}"
459
+ else:
460
+ df.loc[idx, col] = "N/A"
461
 
462
  return df
463