AlvaroMros commited on
Commit
eb615ca
·
1 Parent(s): f972c61

Add k-fold cross-validation to prediction pipeline

Browse files

Introduces a --kfold argument to main.py to enable 3-fold cross-validation. Implements run_kfold_cv in pipeline.py, using event-based splits and MLflow for experiment tracking and model registration. Refactors imports and typing for consistency, and moves configuration constants to config.py for better modularity.

src/predict/main.py CHANGED
@@ -55,6 +55,11 @@ def main():
55
  default=False,
56
  help="Force retrain all models even if no new data is available."
57
  )
 
 
 
 
 
58
  args = parser.parse_args()
59
 
60
  # Handle conflicting arguments
@@ -75,9 +80,15 @@ def main():
75
  use_existing_models=use_existing_models,
76
  force_retrain=force_retrain
77
  )
78
-
79
  try:
80
- pipeline.run(detailed_report=(args.report == 'detailed'))
 
 
 
 
81
  except FileNotFoundError as e:
82
  print(f"Error: {e}")
83
  print("Please ensure the required data files exist. You may need to run the scraping and ELO analysis first.")
 
 
 
 
55
  default=False,
56
  help="Force retrain all models even if no new data is available."
57
  )
58
+ parser.add_argument(
59
+ '--kfold',
60
+ action='store_true',
61
+ help='Run 3-fold CV instead of standard split.'
62
+ )
63
  args = parser.parse_args()
64
 
65
  # Handle conflicting arguments
 
80
  use_existing_models=use_existing_models,
81
  force_retrain=force_retrain
82
  )
 
83
  try:
84
+ if args.kfold:
85
+ cv_results = pipeline.run_kfold_cv(k=3, holdout_events=1)
86
+ print(cv_results)
87
+ else:
88
+ pipeline.run(detailed_report=(args.report == 'detailed'))
89
  except FileNotFoundError as e:
90
  print(f"Error: {e}")
91
  print("Please ensure the required data files exist. You may need to run the scraping and ELO analysis first.")
92
+
93
+ if __name__ == '__main__':
94
+ main()
src/predict/models.py CHANGED
@@ -12,7 +12,8 @@ from lightgbm import LGBMClassifier
12
  from ..analysis.elo import process_fights_for_elo, INITIAL_ELO
13
  from ..config import FIGHTERS_CSV_PATH
14
  from .preprocess import preprocess_for_ml, _get_fighter_history_stats
15
- from .utils import calculate_age, prepare_fighters_data, DEFAULT_ELO
 
16
 
17
  class BaseModel(ABC):
18
  """
@@ -87,7 +88,7 @@ class BaseMLModel(BaseModel):
87
  self.fighters_df = None
88
  self.fighter_histories = {}
89
 
90
- def train(self, train_fights: List[Dict[str, Any]]) -> None:
91
  """
92
  Trains the machine learning model. This involves loading fighter data,
93
  pre-calculating histories, and fitting the model on the preprocessed data.
 
12
  from ..analysis.elo import process_fights_for_elo, INITIAL_ELO
13
  from ..config import FIGHTERS_CSV_PATH
14
  from .preprocess import preprocess_for_ml, _get_fighter_history_stats
15
+ from .utils import calculate_age, prepare_fighters_data
16
+ from .config import DEFAULT_ELO
17
 
18
  class BaseModel(ABC):
19
  """
 
88
  self.fighters_df = None
89
  self.fighter_histories = {}
90
 
91
+ def train(self, train_fights: list[dict[str, any]]) -> None:
92
  """
93
  Trains the machine learning model. This involves loading fighter data,
94
  pre-calculating histories, and fitting the model on the preprocessed data.
src/predict/pipeline.py CHANGED
@@ -25,6 +25,9 @@ import json
25
  import joblib
26
  from ..config import FIGHTS_CSV_PATH, MODEL_RESULTS_PATH, MODELS_DIR, LAST_EVENT_JSON_PATH
27
  from .models import BaseModel
 
 
 
28
 
29
  class PredictionPipeline:
30
  """
@@ -248,6 +251,72 @@ class PredictionPipeline:
248
  if should_retrain:
249
  self._train_and_save_models()
250
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  def update_models_if_new_data(self):
252
  """
253
  Checks for new data and retrains/saves all models on the full dataset if needed.
 
25
  import joblib
26
  from ..config import FIGHTS_CSV_PATH, MODEL_RESULTS_PATH, MODELS_DIR, LAST_EVENT_JSON_PATH
27
  from .models import BaseModel
28
+ from sklearn.model_selection import KFold
29
+ import mlflow
30
+ import mlflow.sklearn
31
 
32
  class PredictionPipeline:
33
  """
 
251
  if should_retrain:
252
  self._train_and_save_models()
253
 
254
+ def run_kfold_cv(self, k: int = 3, holdout_events: int = 1):
255
+ """Performs k-fold cross-validation where each fold is a set of events.
256
+ Within each fold, we keep the last *holdout_events* for testing."""
257
+ fights = self._load_fights()
258
+
259
+ # Build an ordered list of unique events
260
+ event_list = list(OrderedDict.fromkeys(f['event_name'] for f in fights))
261
+
262
+ # Initialize KFold splitter on events
263
+ kf = KFold(n_splits=k, shuffle=True, random_state=42)
264
+
265
+ all_fold_metrics = []
266
+ for fold_idx, (train_event_idx, test_event_idx) in enumerate(kf.split(event_list), start=1):
267
+ train_events = [event_list[i] for i in train_event_idx]
268
+
269
+ # Collect fights that belong to the training events
270
+ fold_fights = [f for f in fights if f['event_name'] in train_events]
271
+
272
+ # Inside this fold, reserve the last `holdout_events` events for testing
273
+ fold_events_ordered = list(OrderedDict.fromkeys(f['event_name'] for f in fold_fights))
274
+ test_events = fold_events_ordered[-holdout_events:]
275
+
276
+ train_set = [f for f in fold_fights if f['event_name'] not in test_events]
277
+ test_set = [f for f in fold_fights if f['event_name'] in test_events]
278
+
279
+ # Start an MLflow run for the current fold
280
+ mlflow.set_experiment("UFC_KFold_CV")
281
+ with mlflow.start_run(run_name=f"fold_{fold_idx}"):
282
+ # Log meta information about the fold
283
+ mlflow.log_param("fold", fold_idx)
284
+ mlflow.log_param("train_events", len(train_events))
285
+ mlflow.log_param("test_events", holdout_events)
286
+
287
+ fold_results = {}
288
+ for model in self.models:
289
+ model_name = model.__class__.__name__
290
+
291
+ # Train and evaluate
292
+ model.train(train_set)
293
+ correct = 0
294
+ total_fights = 0
295
+ for fight in test_set:
296
+ if fight['winner'] not in ["Draw", "NC", ""]:
297
+ prediction = model.predict(fight)
298
+ if prediction.get('winner') == fight['winner']:
299
+ correct += 1
300
+ total_fights += 1
301
+
302
+ acc = correct / total_fights if total_fights > 0 else 0.0
303
+ fold_results[model_name] = acc
304
+
305
+ # Log metrics and register model to appear in MLflow Models tab
306
+ mlflow.log_metric(f"accuracy_{model_name}", acc)
307
+ mlflow.log_metric(f"total_fights_{model_name}", total_fights)
308
+
309
+ # Register the model with MLflow to appear in Models tab
310
+ mlflow.sklearn.log_model(
311
+ model,
312
+ f"model_{model_name}",
313
+ registered_model_name=f"{model_name}_UFC_Model"
314
+ )
315
+
316
+ all_fold_metrics.append(fold_results)
317
+
318
+ return all_fold_metrics
319
+
320
  def update_models_if_new_data(self):
321
  """
322
  Checks for new data and retrains/saves all models on the full dataset if needed.
src/predict/preprocess.py CHANGED
@@ -1,22 +1,21 @@
1
  import pandas as pd
2
  import os
3
  from datetime import datetime
4
- from typing import Dict, List, Tuple, Any, Optional
5
- from ..config import FIGHTERS_CSV_PATH
6
  from .utils import (
7
  parse_round_time_to_seconds, parse_striking_stats, to_int_safe,
8
- calculate_age, prepare_fighters_data, DEFAULT_ELO, N_FIGHTS_HISTORY
9
  )
 
10
 
11
 
12
 
13
  def _get_fighter_history_stats(
14
  fighter_name: str,
15
  current_fight_date: datetime,
16
- fighter_history: List[Dict[str, Any]],
17
  fighters_df: pd.DataFrame,
18
  n: int = N_FIGHTS_HISTORY
19
- ) -> Dict[str, float]:
20
  """
21
  Calculates performance statistics for a fighter based on their last n fights.
22
  """
@@ -82,9 +81,9 @@ def _get_fighter_history_stats(
82
  }
83
 
84
  def preprocess_for_ml(
85
- fights_to_process: List[Dict[str, Any]],
86
  fighters_csv_path: str
87
- ) -> Tuple[pd.DataFrame, pd.Series, pd.DataFrame]:
88
  """
89
  Transforms raw fight and fighter data into a feature matrix (X) and target vector (y)
90
  suitable for a binary classification machine learning model.
@@ -135,8 +134,8 @@ def preprocess_for_ml(
135
  if isinstance(f2_stats, pd.DataFrame): f2_stats = f2_stats.iloc[0]
136
 
137
  # Calculate ages for both fighters
138
- f1_age = _calculate_age(f1_stats.get('dob'), fight['event_date'])
139
- f2_age = _calculate_age(f2_stats.get('dob'), fight['event_date'])
140
 
141
  # Get historical stats for both fighters
142
  f1_hist_stats = _get_fighter_history_stats(f1_name, fight['date_obj'], fighter_histories.get(f1_name, []), fighters_prepared)
 
1
  import pandas as pd
2
  import os
3
  from datetime import datetime
 
 
4
  from .utils import (
5
  parse_round_time_to_seconds, parse_striking_stats, to_int_safe,
6
+ calculate_age, prepare_fighters_data
7
  )
8
+ from .config import DEFAULT_ELO, N_FIGHTS_HISTORY
9
 
10
 
11
 
12
  def _get_fighter_history_stats(
13
  fighter_name: str,
14
  current_fight_date: datetime,
15
+ fighter_history: list[dict[str, any]],
16
  fighters_df: pd.DataFrame,
17
  n: int = N_FIGHTS_HISTORY
18
+ ) -> dict[str, float]:
19
  """
20
  Calculates performance statistics for a fighter based on their last n fights.
21
  """
 
81
  }
82
 
83
  def preprocess_for_ml(
84
+ fights_to_process: list[dict[str, any]],
85
  fighters_csv_path: str
86
+ ) -> tuple[pd.DataFrame, pd.Series, pd.DataFrame]:
87
  """
88
  Transforms raw fight and fighter data into a feature matrix (X) and target vector (y)
89
  suitable for a binary classification machine learning model.
 
134
  if isinstance(f2_stats, pd.DataFrame): f2_stats = f2_stats.iloc[0]
135
 
136
  # Calculate ages for both fighters
137
+ f1_age = calculate_age(f1_stats.get('dob'), fight['event_date'])
138
+ f2_age = calculate_age(f2_stats.get('dob'), fight['event_date'])
139
 
140
  # Get historical stats for both fighters
141
  f1_hist_stats = _get_fighter_history_stats(f1_name, fight['date_obj'], fighter_histories.get(f1_name, []), fighters_prepared)
src/predict/utils.py CHANGED
@@ -1,14 +1,8 @@
1
  import pandas as pd
2
- import os
3
  from datetime import datetime
4
- from typing import Optional, Dict, Any
5
 
6
- # Constants
7
- DEFAULT_ELO = 1500
8
- DEFAULT_AGE = 0
9
- DEFAULT_FIGHT_TIME = 0
10
- DEFAULT_ROUNDS_DURATION = 5 * 60 # 5 minutes per round
11
- N_FIGHTS_HISTORY = 5
12
 
13
  def clean_numeric_column(series: pd.Series) -> pd.Series:
14
  """A helper to clean string columns into numbers, handling errors."""
 
1
  import pandas as pd
 
2
  from datetime import datetime
3
+ from typing import Optional, Any
4
 
5
+ from .config import DEFAULT_ROUNDS_DURATION
 
 
 
 
 
6
 
7
  def clean_numeric_column(series: pd.Series) -> pd.Series:
8
  """A helper to clean string columns into numbers, handling errors."""