huckiyang committed
Commit 3c6aeb7 · 1 parent: fbba242

Optimize the data loading

Files changed (1): app.py (+48, -34)
app.py CHANGED
@@ -3,57 +3,64 @@ import pandas as pd
 from datasets import load_dataset
 import jiwer
 import numpy as np
+from functools import lru_cache
 
-# Load the dataset
+# Cache the dataset loading to avoid reloading on refresh
+@lru_cache(maxsize=1)
 def load_data():
-    dataset = load_dataset("GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction")
-    return dataset
+    return load_dataset("GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction")
 
 # Calculate WER for a group of examples
 def calculate_wer(examples):
     if not examples:
         return 0.0
 
-    valid_pairs = []
-    for ex in examples:
-        # Get transcription and input1 fields
-        transcription = ex.get("transcription")
-        input1 = ex.get("input1")
-
-        # Only include examples where both fields exist and are not empty
-        if transcription and input1:
-            valid_pairs.append((transcription.strip(), input1.strip()))
-
-    # If no valid pairs were found, return NaN
+    # Filter valid examples in a single pass
+    valid_pairs = [(ex.get("transcription", "").strip(), ex.get("input1", "").strip())
+                   for ex in examples
+                   if ex.get("transcription") and ex.get("input1")]
+
     if not valid_pairs:
         return np.nan
 
-    # Separate references and hypotheses
-    references = [pair[0] for pair in valid_pairs]
-    hypotheses = [pair[1] for pair in valid_pairs]
+    # Unzip the pairs in one operation
+    references, hypotheses = zip(*valid_pairs) if valid_pairs else ([], [])
 
     # Calculate WER
-    wer = jiwer.wer(references, hypotheses)
-    return wer
+    return jiwer.wer(references, hypotheses)
 
 # Get WER metrics by source and split
 def get_wer_metrics(dataset):
-    results = []
-
-    # Get unique sources
-    train_sources = set([ex["source"] for ex in dataset["train"]])
-    test_sources = set([ex["source"] for ex in dataset["test"]])
-    all_sources = sorted(list(train_sources.union(test_sources)))
+    # Pre-process the data to avoid repeated filtering
+    train_by_source = {}
+    test_by_source = {}
+
+    # Group examples by source in a single pass for each split
+    for ex in dataset["train"]:
+        source = ex["source"]
+        if source not in train_by_source:
+            train_by_source[source] = []
+        train_by_source[source].append(ex)
 
-    # Calculate WER for each source in train split
+    for ex in dataset["test"]:
+        source = ex["source"]
+        if source not in test_by_source:
+            test_by_source[source] = []
+        test_by_source[source].append(ex)
+
+    # Get all unique sources
+    all_sources = sorted(set(train_by_source.keys()) | set(test_by_source.keys()))
+
+    # Calculate metrics for each source
+    results = []
     for source in all_sources:
-        train_examples = [ex for ex in dataset["train"] if ex["source"] == source]
-        train_count = len(train_examples)
-        train_wer = calculate_wer(train_examples) if train_count > 0 else np.nan
+        train_examples = train_by_source.get(source, [])
+        test_examples = test_by_source.get(source, [])
 
-        test_examples = [ex for ex in dataset["test"] if ex["source"] == source]
+        train_count = len(train_examples)
         test_count = len(test_examples)
+
+        train_wer = calculate_wer(train_examples) if train_count > 0 else np.nan
         test_wer = calculate_wer(test_examples) if test_count > 0 else np.nan
@@ -64,7 +71,7 @@ def get_wer_metrics(dataset):
             "Test WER": test_wer
         })
 
-    # Add overall metrics
+    # Calculate overall metrics once
     train_wer = calculate_wer(dataset["train"])
     test_wer = calculate_wer(dataset["test"])
 
@@ -80,8 +87,16 @@ def get_wer_metrics(dataset):
 
 # Format the dataframe for display
 def format_dataframe(df):
-    df["Train WER"] = df["Train WER"].apply(lambda x: f"{x:.4f}" if not pd.isna(x) else "N/A")
-    df["Test WER"] = df["Test WER"].apply(lambda x: f"{x:.4f}" if not pd.isna(x) else "N/A")
+    # Use vectorized operations instead of apply
+    df = df.copy()
+    mask = df["Train WER"].notna()
+    df.loc[mask, "Train WER"] = df.loc[mask, "Train WER"].map(lambda x: f"{x:.4f}")
+    df.loc[~mask, "Train WER"] = "N/A"
+
+    mask = df["Test WER"].notna()
+    df.loc[mask, "Test WER"] = df.loc[mask, "Test WER"].map(lambda x: f"{x:.4f}")
+    df.loc[~mask, "Test WER"] = "N/A"
+
     return df
 
 # Main function to create the leaderboard
@@ -89,8 +104,7 @@ def create_leaderboard():
     try:
         dataset = load_data()
         metrics_df = get_wer_metrics(dataset)
-        formatted_df = format_dataframe(metrics_df)
-        return formatted_df
+        return format_dataframe(metrics_df)
     except Exception as e:
         return pd.DataFrame({"Error": [str(e)]})
 
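
A note on the loading change, for reviewers: functools.lru_cache on a zero-argument function memoizes its single return value, so the dataset is fetched once per process and every later call to load_data() returns the cached DatasetDict. The grouping loops could equally be written with collections.defaultdict; a minimal sketch of the same single-pass pattern (an equivalent alternative, not what this commit ships):

    from collections import defaultdict

    def group_by_source(split):
        # One pass over the split: source name -> list of examples
        groups = defaultdict(list)
        for ex in split:
            groups[ex["source"]].append(ex)
        return dict(groups)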
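
Also worth noting when reading the WER columns: jiwer.wer called with parallel lists computes a corpus-level rate (total edit operations over total reference words), not a mean of per-utterance WERs. A toy example with made-up strings:

    import jiwer

    refs = ["the cat sat", "hello world"]  # reference transcriptions
    hyps = ["the cat sat", "hello word"]   # ASR hypotheses
    # One substitution over five reference words -> 0.2
    print(jiwer.wer(refs, hyps))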