Tonic committed
Commit 71db310 · 1 Parent(s): 924581c

adds local and remote training monitors to config

scripts/training/train_gpt_oss.py CHANGED
@@ -19,6 +19,11 @@ except Exception: # pragma: no cover - optional import depending on TRL version
     DPOTrainer = None
 from datasets import load_dataset
 from pathlib import Path
+# Import monitoring utilities from project src for persistent logging
+try:
+    from src.monitoring import create_monitor_from_config  # type: ignore
+except Exception:
+    create_monitor_from_config = None  # type: ignore
 
 # Ensure project root and config package are importable for configs that do `from config...` imports
 project_root = Path(__file__).resolve().parents[2]
@@ -876,6 +881,23 @@ def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer
     # Setup Trackio tracking
     trackio_client = setup_trackio_tracking(config)
 
+    # Initialize project monitor (HF Datasets + Trackio Space if configured)
+    monitor = None
+    monitor_callback = None
+    if create_monitor_from_config is not None:
+        try:
+            monitor = create_monitor_from_config(config, experiment_name=experiment_name)
+            # Persist configuration immediately
+            try:
+                cfg_dict = {k: v for k, v in config.__dict__.items() if not k.startswith('_')}
+                monitor.log_config(cfg_dict)
+            except Exception:
+                pass
+            # Create callback for SFTTrainer
+            monitor_callback = monitor.create_monitoring_callback()
+        except Exception:
+            monitor = None
+
     # Create SFT configuration
     sft_config = create_sft_config(config, output_dir)
 
@@ -949,6 +971,10 @@ def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer
     if "packing" in sft_params:
         sft_kwargs["packing"] = getattr(config, 'packing', False)
 
+    # Attach monitoring callback if supported
+    if "callbacks" in sft_params:
+        sft_kwargs["callbacks"] = ([monitor_callback] if monitor_callback is not None else [])
+
     # Remove any None values
     sft_kwargs = {k: v for k, v in sft_kwargs.items() if v is not None}
 
@@ -959,7 +985,15 @@ def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer
 
     # Start training
     print("Starting GPT-OSS training...")
-    trainer.train()
+    try:
+        trainer.train()
+    finally:
+        # Ensure periodic metrics are flushed at the end even if interrupted
+        try:
+            if monitor is not None:
+                monitor._save_to_hf_dataset({'status': 'running'})
+        except Exception:
+            pass
 
     # Save model
     print("Saving trained model...")
@@ -970,6 +1004,18 @@ def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer
     print("Pushing model to Hugging Face Hub...")
     trainer.push_to_hub(dataset_name="HuggingFaceH4/Multilingual-Thinking")
 
+    # Log training summary and close monitor
+    try:
+        if monitor is not None:
+            summary = {
+                'output_dir': output_dir,
+                'model_name': getattr(config, 'model_name', 'unknown'),
+            }
+            monitor.log_training_summary(summary)
+            monitor.close()
+    except Exception:
+        pass
+
     print("GPT-OSS training completed successfully!")
 
     return trainer
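
As a quick orientation to the wiring above, the sketch below drives the same monitor path outside the training script. It is not part of the commit: the config object and experiment name are hypothetical, and only the create_monitor_from_config / log_config / create_monitoring_callback surface shown in the diff is assumed.

# Sketch only: exercises the monitor path added in train_gpt_oss.py above.
# Assumes src.monitoring exposes create_monitor_from_config as in this commit;
# the config values and experiment name are made up for illustration.
from types import SimpleNamespace

try:
    from src.monitoring import create_monitor_from_config  # same guarded import as the diff
except Exception:
    create_monitor_from_config = None

config = SimpleNamespace(model_name="openai/gpt-oss-20b", packing=False)  # hypothetical config

monitor = None
callbacks = []
if create_monitor_from_config is not None:
    try:
        monitor = create_monitor_from_config(config, experiment_name="gpt-oss-demo")
        # Persist the configuration up front, mirroring the script.
        monitor.log_config({k: v for k, v in vars(config).items() if not k.startswith("_")})
        # The callback is what the script hands to SFTTrainer via sft_kwargs["callbacks"].
        cb = monitor.create_monitoring_callback()
        if cb is not None:
            callbacks.append(cb)
    except Exception:
        monitor = None  # fall back to console-only monitoring, as the script does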
src/monitoring.py CHANGED
@@ -50,6 +50,11 @@ class SmolLM3Monitor:
         self.log_artifacts = log_artifacts
         self.log_metrics_enabled = log_metrics # Rename to avoid conflict
         self.log_config_enabled = log_config # Rename to avoid conflict
+        # Flush interval for dataset persistence (metrics)
+        try:
+            self.flush_interval = int(os.environ.get('TRACKIO_FLUSH_INTERVAL', '10'))
+        except Exception:
+            self.flush_interval = 10
 
         # HF Datasets configuration
         self.hf_token = hf_token or os.environ.get('HF_TOKEN')
@@ -343,12 +348,12 @@ class SmolLM3Monitor:
 
     def log_configuration(self, config: Dict[str, Any]):
         """Log experiment configuration"""
-        if not self.enable_tracking or not self.log_config_enabled:
+        if not self.log_config_enabled:
             return
 
         try:
             # Log configuration as parameters
-            if self.trackio_client:
+            if self.enable_tracking and self.trackio_client:
                 try:
                     result = self.trackio_client.log_parameters(
                         experiment_id=self.experiment_id,
@@ -390,7 +395,7 @@ class SmolLM3Monitor:
         - throughput, step_time, batch_size, seq_len
         - token_acc, train/gate_ortho, train/center, etc.
         """
-        if not self.enable_tracking or not self.log_metrics_enabled:
+        if not self.log_metrics_enabled:
             return
 
         try:
@@ -400,7 +405,7 @@ class SmolLM3Monitor:
             metrics['step'] = step
 
             # Log to Trackio (if available)
-            if self.trackio_client:
+            if self.enable_tracking and self.trackio_client:
                 try:
                     result = self.trackio_client.log_metrics(
                         experiment_id=self.experiment_id,
@@ -418,8 +423,8 @@ class SmolLM3Monitor:
             # Store locally
             self.metrics_history.append(metrics)
 
-            # Save to HF Dataset periodically
-            if len(self.metrics_history) % 10 == 0: # Save every 10 metrics
+            # Save to HF Dataset periodically (configurable)
+            if self.flush_interval > 0 and (len(self.metrics_history) % self.flush_interval == 0):
                 self._save_to_hf_dataset({'metrics': self.metrics_history})
 
             logger.debug("Metrics logged: %s", metrics)
@@ -429,7 +434,7 @@ class SmolLM3Monitor:
 
     def log_model_checkpoint(self, checkpoint_path: str, step: Optional[int] = None):
         """Log model checkpoint"""
-        if not self.enable_tracking or not self.log_artifacts:
+        if not self.log_artifacts:
             return
 
         try:
@@ -441,7 +446,7 @@ class SmolLM3Monitor:
                 "checkpoint_size": os.path.getsize(checkpoint_path) if os.path.exists(checkpoint_path) else 0
             }
 
-            if self.trackio_client:
+            if self.enable_tracking and self.trackio_client:
                 result = self.trackio_client.log_parameters(
                     experiment_id=self.experiment_id,
                     parameters=checkpoint_info
@@ -453,6 +458,11 @@ class SmolLM3Monitor:
                     logger.error("Failed to log checkpoint to Trackio: %s", result)
 
             self.artifacts.append(checkpoint_path)
+            # Also preserve checkpoint info in HF dataset
+            try:
+                self._save_to_hf_dataset({'artifacts': [checkpoint_path], **checkpoint_info})
+            except Exception:
+                pass
             logger.info("Checkpoint logged: %s", checkpoint_path)
 
         except Exception as e:
@@ -460,9 +470,6 @@ class SmolLM3Monitor:
 
     def log_evaluation_results(self, results: Dict[str, Any], step: Optional[int] = None):
         """Log evaluation results"""
-        if not self.enable_tracking:
-            return
-
         try:
             # Add evaluation prefix to metrics
             eval_metrics = {f"eval_{k}": v for k, v in results.items()}
@@ -485,9 +492,6 @@ class SmolLM3Monitor:
 
     def log_system_metrics(self, step: Optional[int] = None):
        """Log system metrics (GPU, memory, etc.)"""
-        if not self.enable_tracking:
-            return
-
         try:
             system_metrics = {}
 
@@ -513,9 +517,6 @@ class SmolLM3Monitor:
 
     def log_training_summary(self, summary: Dict[str, Any]):
         """Log training summary at the end"""
-        if not self.enable_tracking:
-            return
-
         try:
             # Add experiment duration
             end_time = datetime.now()
@@ -524,7 +525,7 @@ class SmolLM3Monitor:
             summary['experiment_duration_hours'] = duration / 3600
 
             # Log final summary to Trackio
-            if self.trackio_client:
+            if self.enable_tracking and self.trackio_client:
                 result = self.trackio_client.log_parameters(
                     experiment_id=self.experiment_id,
                     parameters=summary
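
The new flush interval is read once from the TRACKIO_FLUSH_INTERVAL environment variable (default 10; 0 disables periodic saves) and then gates how often log_metrics writes metrics_history to the HF dataset. Below is a standalone sketch of that cadence, not part of the commit; the print call merely stands in for the monitor's _save_to_hf_dataset.

# Sketch only: illustrates the periodic-flush check added to SmolLM3Monitor.log_metrics.
import os

os.environ.setdefault("TRACKIO_FLUSH_INTERVAL", "25")  # flush every 25 logged metric dicts
flush_interval = int(os.environ.get("TRACKIO_FLUSH_INTERVAL", "10"))

metrics_history = []
for step in range(1, 101):
    metrics_history.append({"loss": 1.0 / step, "step": step})
    # Same condition as the diff: a non-positive interval disables periodic persistence.
    if flush_interval > 0 and len(metrics_history) % flush_interval == 0:
        print(f"flush {len(metrics_history)} records to the HF dataset")  # stand-in for _save_to_hf_dataset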
src/trainer.py CHANGED
@@ -78,6 +78,7 @@ class SmolLM3Trainer:
         # Add simple console callback for basic monitoring
         from transformers import TrainerCallback
 
+        outer = self
         class SimpleConsoleCallback(TrainerCallback):
             def on_init_end(self, args, state, control, **kwargs):
                 """Called when training initialization is complete"""
@@ -99,6 +100,16 @@ class SmolLM3Trainer:
                     else:
                         lr_str = str(lr)
                     print(f"Step {step}: loss={loss_str}, lr={lr_str}")
+
+                    # Persist metrics via our monitor when Trackio callback isn't active
+                    try:
+                        if outer.monitor:
+                            # Avoid double logging when Trackio callback is used
+                            if not outer.monitor.enable_tracking:
+                                outer.monitor.log_metrics(dict(logs), step if isinstance(step, int) else None)
+                                outer.monitor.log_system_metrics(step if isinstance(step, int) else None)
+                    except Exception as e:
+                        logger.warning("SimpleConsoleCallback metrics persistence failed: %s", e)
 
             def on_train_begin(self, args, state, control, **kwargs):
                 print("🚀 Training started!")
@@ -109,28 +120,40 @@ class SmolLM3Trainer:
             def on_save(self, args, state, control, **kwargs):
                 step = state.global_step if hasattr(state, 'global_step') else 'unknown'
                 print(f"💾 Checkpoint saved at step {step}")
+                try:
+                    if outer.monitor and not outer.monitor.enable_tracking:
+                        checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{step}")
+                        if os.path.exists(checkpoint_path):
+                            outer.monitor.log_model_checkpoint(checkpoint_path, step if isinstance(step, int) else None)
+                except Exception as e:
+                    logger.warning("SimpleConsoleCallback checkpoint persistence failed: %s", e)
 
             def on_evaluate(self, args, state, control, metrics=None, **kwargs):
                 if metrics and isinstance(metrics, dict):
                     step = state.global_step if hasattr(state, 'global_step') else 'unknown'
                     eval_loss = metrics.get('eval_loss', 'N/A')
                     print(f"📊 Evaluation at step {step}: eval_loss={eval_loss}")
+                    try:
+                        if outer.monitor and not outer.monitor.enable_tracking:
+                            outer.monitor.log_evaluation_results(dict(metrics), step if isinstance(step, int) else None)
+                    except Exception as e:
+                        logger.warning("SimpleConsoleCallback eval persistence failed: %s", e)
 
         # Add console callback
         callbacks.append(SimpleConsoleCallback())
         logger.info("Added simple console monitoring callback")
 
-        # Add Trackio callback if available
-        if self.monitor and self.monitor.enable_tracking:
+        # Add monitoring callback if available (always attach; it persists to dataset even if Trackio is disabled)
+        if self.monitor:
             try:
                 trackio_callback = self.monitor.create_monitoring_callback()
                 if trackio_callback:
                     callbacks.append(trackio_callback)
-                    logger.info("Added Trackio monitoring callback")
+                    logger.info("Added monitoring callback")
                 else:
-                    logger.warning("Failed to create Trackio callback")
+                    logger.warning("Failed to create monitoring callback")
             except Exception as e:
-                logger.error("Error creating Trackio callback: %s", e)
+                logger.error("Error creating monitoring callback: %s", e)
                 logger.info("Continuing with console monitoring only")
 
         logger.info("Total callbacks: %d", len(callbacks))
@@ -220,16 +243,20 @@ class SmolLM3Trainer:
         """Start training"""
         logger.info("Starting training")
 
-        # Log configuration to Trackio
-        if self.monitor and self.monitor.enable_tracking:
-            config_dict = {k: v for k, v in self.config.__dict__.items()
-                           if not k.startswith('_')}
-            self.monitor.log_config(config_dict)
-
-            # Log experiment URL
-            experiment_url = self.monitor.get_experiment_url()
-            if experiment_url:
-                logger.info("Trackio experiment URL: %s", experiment_url)
+        # Log configuration (always persist to dataset; Trackio if enabled)
+        if self.monitor:
+            try:
+                config_dict = {k: v for k, v in self.config.__dict__.items() if not k.startswith('_')}
+                self.monitor.log_config(config_dict)
+            except Exception as e:
+                logger.warning("Failed to log configuration: %s", e)
+            # Log experiment URL only if available
+            try:
+                experiment_url = self.monitor.get_experiment_url()
+                if experiment_url:
+                    logger.info("Trackio experiment URL: %s", experiment_url)
+            except Exception:
+                pass
 
         # Load checkpoint if resuming
         if self.init_from == "resume":
@@ -251,17 +278,20 @@ class SmolLM3Trainer:
             with open(os.path.join(self.output_dir, "train_results.json"), "w") as f:
                 json.dump(train_result.metrics, f, indent=2)
 
-            # Log training summary to Trackio
-            if self.monitor and self.monitor.enable_tracking:
-                summary = {
-                    'final_loss': train_result.metrics.get('train_loss', 0),
-                    'total_steps': train_result.metrics.get('train_runtime', 0),
-                    'training_time': train_result.metrics.get('train_runtime', 0),
-                    'output_dir': self.output_dir,
-                    'model_name': getattr(self.config, 'model_name', 'unknown'),
-                }
-                self.monitor.log_training_summary(summary)
-                self.monitor.close()
+            # Log training summary (always persist to dataset; Trackio if enabled)
+            if self.monitor:
+                try:
+                    summary = {
+                        'final_loss': train_result.metrics.get('train_loss', 0),
+                        'total_steps': train_result.metrics.get('train_runtime', 0),
+                        'training_time': train_result.metrics.get('train_runtime', 0),
+                        'output_dir': self.output_dir,
+                        'model_name': getattr(self.config, 'model_name', 'unknown'),
+                    }
+                    self.monitor.log_training_summary(summary)
+                    self.monitor.close()
+                except Exception as e:
+                    logger.warning("Failed to log training summary: %s", e)
 
             # Finish trackio experiment
             try:
@@ -276,9 +306,12 @@ class SmolLM3Trainer:
 
         except Exception as e:
             logger.error("Training failed: %s", e)
-            # Close monitoring on error
-            if self.monitor and self.monitor.enable_tracking:
-                self.monitor.close()
+            # Close monitoring on error (still persist final status to dataset)
+            if self.monitor:
+                try:
+                    self.monitor.close(final_status="failed")
+                except Exception:
+                    pass
 
             # Finish trackio experiment on error
             try:
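
The SimpleConsoleCallback changes hinge on one guard: persist through the monitor only when enable_tracking is False, since otherwise the separately attached monitoring callback already records the same data. A standalone sketch of that guard follows; it is not part of the commit, `monitor` stands in for a SmolLM3Monitor instance, and only the attributes and methods referenced in the diff are assumed.

# Sketch only: the "avoid double logging" rule used by SimpleConsoleCallback above.
from typing import Any, Dict, Optional


def persist_console_logs(monitor: Any, logs: Dict[str, Any], step: Optional[int]) -> None:
    """Forward console-callback logs to the dataset-backed monitor, unless Trackio is active."""
    if monitor is None:
        return
    if getattr(monitor, "enable_tracking", False):
        # The Trackio monitoring callback is attached separately and already
        # records these metrics, so skip to avoid duplicate entries.
        return
    monitor.log_metrics(dict(logs), step if isinstance(step, int) else None)
    monitor.log_system_metrics(step if isinstance(step, int) else None)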