more LM baseline
Browse files
app.py
CHANGED
@@ -404,24 +404,31 @@ def get_wer_metrics(dataset):
|
|
404 |
"N-best Correction": np.nan
|
405 |
}
|
406 |
|
407 |
-
# Create
|
408 |
-
|
409 |
-
result_df = pd.DataFrame(index=metrics, columns=["Metric"] + all_sources + ["OVERALL"])
|
410 |
|
411 |
-
#
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
|
|
|
|
|
|
418 |
|
419 |
for source in all_sources + ["OVERALL"]:
|
420 |
-
|
421 |
-
|
|
|
|
|
|
|
|
|
|
|
422 |
|
423 |
-
#
|
424 |
-
result_df =
|
425 |
|
426 |
return result_df
|
427 |
|
@@ -438,20 +445,19 @@ def format_dataframe(df):
|
|
438 |
|
439 |
# Find the rows containing WER values
|
440 |
wer_row_indices = []
|
441 |
-
for
|
442 |
-
if "WER" in
|
443 |
-
wer_row_indices.append(
|
444 |
|
445 |
-
|
446 |
-
|
447 |
-
df.loc[wer_row_index] = df.loc[wer_row_index].astype(object)
|
448 |
-
|
449 |
for col in df.columns:
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
|
|
455 |
|
456 |
return df
|
457 |
|
|
|
404 |
"N-best Correction": np.nan
|
405 |
}
|
406 |
|
407 |
+
# Create flat DataFrame with labels in the first column
|
408 |
+
rows = []
|
|
|
409 |
|
410 |
+
# First add row for number of examples
|
411 |
+
example_row = {"Metric": "Number of Examples"}
|
412 |
+
for source in all_sources + ["OVERALL"]:
|
413 |
+
example_row[source] = source_results[source]["Count"]
|
414 |
+
rows.append(example_row)
|
415 |
+
|
416 |
+
# Then add rows for each WER method
|
417 |
+
no_lm_row = {"Metric": "Word Error Rate (No LM)"}
|
418 |
+
lm_ranking_row = {"Metric": "Word Error Rate (N-best LM Ranking)"}
|
419 |
+
n_best_row = {"Metric": "Word Error Rate (N-best Correction)"}
|
420 |
|
421 |
for source in all_sources + ["OVERALL"]:
|
422 |
+
no_lm_row[source] = source_results[source]["No LM Baseline"]
|
423 |
+
lm_ranking_row[source] = source_results[source]["N-best LM Ranking"]
|
424 |
+
n_best_row[source] = source_results[source]["N-best Correction"]
|
425 |
+
|
426 |
+
rows.append(no_lm_row)
|
427 |
+
rows.append(lm_ranking_row)
|
428 |
+
rows.append(n_best_row)
|
429 |
|
430 |
+
# Create DataFrame from rows
|
431 |
+
result_df = pd.DataFrame(rows)
|
432 |
|
433 |
return result_df
|
434 |
|
|
|
445 |
|
446 |
# Find the rows containing WER values
|
447 |
wer_row_indices = []
|
448 |
+
for i, metric in enumerate(df["Metric"]):
|
449 |
+
if "WER" in metric or "Error Rate" in metric:
|
450 |
+
wer_row_indices.append(i)
|
451 |
|
452 |
+
# Format WER values
|
453 |
+
for idx in wer_row_indices:
|
|
|
|
|
454 |
for col in df.columns:
|
455 |
+
if col != "Metric": # Skip the metric column
|
456 |
+
value = df.loc[idx, col]
|
457 |
+
if pd.notna(value):
|
458 |
+
df.loc[idx, col] = f"{value:.4f}"
|
459 |
+
else:
|
460 |
+
df.loc[idx, col] = "N/A"
|
461 |
|
462 |
return df
|
463 |
|