Spaces:
Running
Running
adds local and remote training monitors to config
Browse files- scripts/training/train_gpt_oss.py +47 -1
- src/monitoring.py +19 -18
- src/trainer.py +62 -29
scripts/training/train_gpt_oss.py
CHANGED
@@ -19,6 +19,11 @@ except Exception: # pragma: no cover - optional import depending on TRL version
|
|
19 |
DPOTrainer = None
|
20 |
from datasets import load_dataset
|
21 |
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
# Ensure project root and config package are importable for configs that do `from config...` imports
|
24 |
project_root = Path(__file__).resolve().parents[2]
|
@@ -876,6 +881,23 @@ def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer
|
|
876 |
# Setup Trackio tracking
|
877 |
trackio_client = setup_trackio_tracking(config)
|
878 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
879 |
# Create SFT configuration
|
880 |
sft_config = create_sft_config(config, output_dir)
|
881 |
|
@@ -949,6 +971,10 @@ def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer
|
|
949 |
if "packing" in sft_params:
|
950 |
sft_kwargs["packing"] = getattr(config, 'packing', False)
|
951 |
|
|
|
|
|
|
|
|
|
952 |
# Remove any None values
|
953 |
sft_kwargs = {k: v for k, v in sft_kwargs.items() if v is not None}
|
954 |
|
@@ -959,7 +985,15 @@ def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer
|
|
959 |
|
960 |
# Start training
|
961 |
print("Starting GPT-OSS training...")
|
962 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
963 |
|
964 |
# Save model
|
965 |
print("Saving trained model...")
|
@@ -970,6 +1004,18 @@ def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer
|
|
970 |
print("Pushing model to Hugging Face Hub...")
|
971 |
trainer.push_to_hub(dataset_name="HuggingFaceH4/Multilingual-Thinking")
|
972 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
973 |
print("GPT-OSS training completed successfully!")
|
974 |
|
975 |
return trainer
|
|
|
19 |
DPOTrainer = None
|
20 |
from datasets import load_dataset
|
21 |
from pathlib import Path
|
22 |
+
# Import monitoring utilities from project src for persistent logging
|
23 |
+
try:
|
24 |
+
from src.monitoring import create_monitor_from_config # type: ignore
|
25 |
+
except Exception:
|
26 |
+
create_monitor_from_config = None # type: ignore
|
27 |
|
28 |
# Ensure project root and config package are importable for configs that do `from config...` imports
|
29 |
project_root = Path(__file__).resolve().parents[2]
|
|
|
881 |
# Setup Trackio tracking
|
882 |
trackio_client = setup_trackio_tracking(config)
|
883 |
|
884 |
+
# Initialize project monitor (HF Datasets + Trackio Space if configured)
|
885 |
+
monitor = None
|
886 |
+
monitor_callback = None
|
887 |
+
if create_monitor_from_config is not None:
|
888 |
+
try:
|
889 |
+
monitor = create_monitor_from_config(config, experiment_name=experiment_name)
|
890 |
+
# Persist configuration immediately
|
891 |
+
try:
|
892 |
+
cfg_dict = {k: v for k, v in config.__dict__.items() if not k.startswith('_')}
|
893 |
+
monitor.log_config(cfg_dict)
|
894 |
+
except Exception:
|
895 |
+
pass
|
896 |
+
# Create callback for SFTTrainer
|
897 |
+
monitor_callback = monitor.create_monitoring_callback()
|
898 |
+
except Exception:
|
899 |
+
monitor = None
|
900 |
+
|
901 |
# Create SFT configuration
|
902 |
sft_config = create_sft_config(config, output_dir)
|
903 |
|
|
|
971 |
if "packing" in sft_params:
|
972 |
sft_kwargs["packing"] = getattr(config, 'packing', False)
|
973 |
|
974 |
+
# Attach monitoring callback if supported
|
975 |
+
if "callbacks" in sft_params:
|
976 |
+
sft_kwargs["callbacks"] = ([monitor_callback] if monitor_callback is not None else [])
|
977 |
+
|
978 |
# Remove any None values
|
979 |
sft_kwargs = {k: v for k, v in sft_kwargs.items() if v is not None}
|
980 |
|
|
|
985 |
|
986 |
# Start training
|
987 |
print("Starting GPT-OSS training...")
|
988 |
+
try:
|
989 |
+
trainer.train()
|
990 |
+
finally:
|
991 |
+
# Ensure periodic metrics are flushed at the end even if interrupted
|
992 |
+
try:
|
993 |
+
if monitor is not None:
|
994 |
+
monitor._save_to_hf_dataset({'status': 'running'})
|
995 |
+
except Exception:
|
996 |
+
pass
|
997 |
|
998 |
# Save model
|
999 |
print("Saving trained model...")
|
|
|
1004 |
print("Pushing model to Hugging Face Hub...")
|
1005 |
trainer.push_to_hub(dataset_name="HuggingFaceH4/Multilingual-Thinking")
|
1006 |
|
1007 |
+
# Log training summary and close monitor
|
1008 |
+
try:
|
1009 |
+
if monitor is not None:
|
1010 |
+
summary = {
|
1011 |
+
'output_dir': output_dir,
|
1012 |
+
'model_name': getattr(config, 'model_name', 'unknown'),
|
1013 |
+
}
|
1014 |
+
monitor.log_training_summary(summary)
|
1015 |
+
monitor.close()
|
1016 |
+
except Exception:
|
1017 |
+
pass
|
1018 |
+
|
1019 |
print("GPT-OSS training completed successfully!")
|
1020 |
|
1021 |
return trainer
|
src/monitoring.py
CHANGED
@@ -50,6 +50,11 @@ class SmolLM3Monitor:
|
|
50 |
self.log_artifacts = log_artifacts
|
51 |
self.log_metrics_enabled = log_metrics # Rename to avoid conflict
|
52 |
self.log_config_enabled = log_config # Rename to avoid conflict
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
# HF Datasets configuration
|
55 |
self.hf_token = hf_token or os.environ.get('HF_TOKEN')
|
@@ -343,12 +348,12 @@ class SmolLM3Monitor:
|
|
343 |
|
344 |
def log_configuration(self, config: Dict[str, Any]):
|
345 |
"""Log experiment configuration"""
|
346 |
-
if not self.
|
347 |
return
|
348 |
|
349 |
try:
|
350 |
# Log configuration as parameters
|
351 |
-
if self.trackio_client:
|
352 |
try:
|
353 |
result = self.trackio_client.log_parameters(
|
354 |
experiment_id=self.experiment_id,
|
@@ -390,7 +395,7 @@ class SmolLM3Monitor:
|
|
390 |
- throughput, step_time, batch_size, seq_len
|
391 |
- token_acc, train/gate_ortho, train/center, etc.
|
392 |
"""
|
393 |
-
if not self.
|
394 |
return
|
395 |
|
396 |
try:
|
@@ -400,7 +405,7 @@ class SmolLM3Monitor:
|
|
400 |
metrics['step'] = step
|
401 |
|
402 |
# Log to Trackio (if available)
|
403 |
-
if self.trackio_client:
|
404 |
try:
|
405 |
result = self.trackio_client.log_metrics(
|
406 |
experiment_id=self.experiment_id,
|
@@ -418,8 +423,8 @@ class SmolLM3Monitor:
|
|
418 |
# Store locally
|
419 |
self.metrics_history.append(metrics)
|
420 |
|
421 |
-
# Save to HF Dataset periodically
|
422 |
-
if len(self.metrics_history) %
|
423 |
self._save_to_hf_dataset({'metrics': self.metrics_history})
|
424 |
|
425 |
logger.debug("Metrics logged: %s", metrics)
|
@@ -429,7 +434,7 @@ class SmolLM3Monitor:
|
|
429 |
|
430 |
def log_model_checkpoint(self, checkpoint_path: str, step: Optional[int] = None):
|
431 |
"""Log model checkpoint"""
|
432 |
-
if not self.
|
433 |
return
|
434 |
|
435 |
try:
|
@@ -441,7 +446,7 @@ class SmolLM3Monitor:
|
|
441 |
"checkpoint_size": os.path.getsize(checkpoint_path) if os.path.exists(checkpoint_path) else 0
|
442 |
}
|
443 |
|
444 |
-
if self.trackio_client:
|
445 |
result = self.trackio_client.log_parameters(
|
446 |
experiment_id=self.experiment_id,
|
447 |
parameters=checkpoint_info
|
@@ -453,6 +458,11 @@ class SmolLM3Monitor:
|
|
453 |
logger.error("Failed to log checkpoint to Trackio: %s", result)
|
454 |
|
455 |
self.artifacts.append(checkpoint_path)
|
|
|
|
|
|
|
|
|
|
|
456 |
logger.info("Checkpoint logged: %s", checkpoint_path)
|
457 |
|
458 |
except Exception as e:
|
@@ -460,9 +470,6 @@ class SmolLM3Monitor:
|
|
460 |
|
461 |
def log_evaluation_results(self, results: Dict[str, Any], step: Optional[int] = None):
|
462 |
"""Log evaluation results"""
|
463 |
-
if not self.enable_tracking:
|
464 |
-
return
|
465 |
-
|
466 |
try:
|
467 |
# Add evaluation prefix to metrics
|
468 |
eval_metrics = {f"eval_{k}": v for k, v in results.items()}
|
@@ -485,9 +492,6 @@ class SmolLM3Monitor:
|
|
485 |
|
486 |
def log_system_metrics(self, step: Optional[int] = None):
|
487 |
"""Log system metrics (GPU, memory, etc.)"""
|
488 |
-
if not self.enable_tracking:
|
489 |
-
return
|
490 |
-
|
491 |
try:
|
492 |
system_metrics = {}
|
493 |
|
@@ -513,9 +517,6 @@ class SmolLM3Monitor:
|
|
513 |
|
514 |
def log_training_summary(self, summary: Dict[str, Any]):
|
515 |
"""Log training summary at the end"""
|
516 |
-
if not self.enable_tracking:
|
517 |
-
return
|
518 |
-
|
519 |
try:
|
520 |
# Add experiment duration
|
521 |
end_time = datetime.now()
|
@@ -524,7 +525,7 @@ class SmolLM3Monitor:
|
|
524 |
summary['experiment_duration_hours'] = duration / 3600
|
525 |
|
526 |
# Log final summary to Trackio
|
527 |
-
if self.trackio_client:
|
528 |
result = self.trackio_client.log_parameters(
|
529 |
experiment_id=self.experiment_id,
|
530 |
parameters=summary
|
|
|
50 |
self.log_artifacts = log_artifacts
|
51 |
self.log_metrics_enabled = log_metrics # Rename to avoid conflict
|
52 |
self.log_config_enabled = log_config # Rename to avoid conflict
|
53 |
+
# Flush interval for dataset persistence (metrics)
|
54 |
+
try:
|
55 |
+
self.flush_interval = int(os.environ.get('TRACKIO_FLUSH_INTERVAL', '10'))
|
56 |
+
except Exception:
|
57 |
+
self.flush_interval = 10
|
58 |
|
59 |
# HF Datasets configuration
|
60 |
self.hf_token = hf_token or os.environ.get('HF_TOKEN')
|
|
|
348 |
|
349 |
def log_configuration(self, config: Dict[str, Any]):
|
350 |
"""Log experiment configuration"""
|
351 |
+
if not self.log_config_enabled:
|
352 |
return
|
353 |
|
354 |
try:
|
355 |
# Log configuration as parameters
|
356 |
+
if self.enable_tracking and self.trackio_client:
|
357 |
try:
|
358 |
result = self.trackio_client.log_parameters(
|
359 |
experiment_id=self.experiment_id,
|
|
|
395 |
- throughput, step_time, batch_size, seq_len
|
396 |
- token_acc, train/gate_ortho, train/center, etc.
|
397 |
"""
|
398 |
+
if not self.log_metrics_enabled:
|
399 |
return
|
400 |
|
401 |
try:
|
|
|
405 |
metrics['step'] = step
|
406 |
|
407 |
# Log to Trackio (if available)
|
408 |
+
if self.enable_tracking and self.trackio_client:
|
409 |
try:
|
410 |
result = self.trackio_client.log_metrics(
|
411 |
experiment_id=self.experiment_id,
|
|
|
423 |
# Store locally
|
424 |
self.metrics_history.append(metrics)
|
425 |
|
426 |
+
# Save to HF Dataset periodically (configurable)
|
427 |
+
if self.flush_interval > 0 and (len(self.metrics_history) % self.flush_interval == 0):
|
428 |
self._save_to_hf_dataset({'metrics': self.metrics_history})
|
429 |
|
430 |
logger.debug("Metrics logged: %s", metrics)
|
|
|
434 |
|
435 |
def log_model_checkpoint(self, checkpoint_path: str, step: Optional[int] = None):
|
436 |
"""Log model checkpoint"""
|
437 |
+
if not self.log_artifacts:
|
438 |
return
|
439 |
|
440 |
try:
|
|
|
446 |
"checkpoint_size": os.path.getsize(checkpoint_path) if os.path.exists(checkpoint_path) else 0
|
447 |
}
|
448 |
|
449 |
+
if self.enable_tracking and self.trackio_client:
|
450 |
result = self.trackio_client.log_parameters(
|
451 |
experiment_id=self.experiment_id,
|
452 |
parameters=checkpoint_info
|
|
|
458 |
logger.error("Failed to log checkpoint to Trackio: %s", result)
|
459 |
|
460 |
self.artifacts.append(checkpoint_path)
|
461 |
+
# Also preserve checkpoint info in HF dataset
|
462 |
+
try:
|
463 |
+
self._save_to_hf_dataset({'artifacts': [checkpoint_path], **checkpoint_info})
|
464 |
+
except Exception:
|
465 |
+
pass
|
466 |
logger.info("Checkpoint logged: %s", checkpoint_path)
|
467 |
|
468 |
except Exception as e:
|
|
|
470 |
|
471 |
def log_evaluation_results(self, results: Dict[str, Any], step: Optional[int] = None):
|
472 |
"""Log evaluation results"""
|
|
|
|
|
|
|
473 |
try:
|
474 |
# Add evaluation prefix to metrics
|
475 |
eval_metrics = {f"eval_{k}": v for k, v in results.items()}
|
|
|
492 |
|
493 |
def log_system_metrics(self, step: Optional[int] = None):
|
494 |
"""Log system metrics (GPU, memory, etc.)"""
|
|
|
|
|
|
|
495 |
try:
|
496 |
system_metrics = {}
|
497 |
|
|
|
517 |
|
518 |
def log_training_summary(self, summary: Dict[str, Any]):
|
519 |
"""Log training summary at the end"""
|
|
|
|
|
|
|
520 |
try:
|
521 |
# Add experiment duration
|
522 |
end_time = datetime.now()
|
|
|
525 |
summary['experiment_duration_hours'] = duration / 3600
|
526 |
|
527 |
# Log final summary to Trackio
|
528 |
+
if self.enable_tracking and self.trackio_client:
|
529 |
result = self.trackio_client.log_parameters(
|
530 |
experiment_id=self.experiment_id,
|
531 |
parameters=summary
|
src/trainer.py
CHANGED
@@ -78,6 +78,7 @@ class SmolLM3Trainer:
|
|
78 |
# Add simple console callback for basic monitoring
|
79 |
from transformers import TrainerCallback
|
80 |
|
|
|
81 |
class SimpleConsoleCallback(TrainerCallback):
|
82 |
def on_init_end(self, args, state, control, **kwargs):
|
83 |
"""Called when training initialization is complete"""
|
@@ -99,6 +100,16 @@ class SmolLM3Trainer:
|
|
99 |
else:
|
100 |
lr_str = str(lr)
|
101 |
print(f"Step {step}: loss={loss_str}, lr={lr_str}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
|
103 |
def on_train_begin(self, args, state, control, **kwargs):
|
104 |
print("🚀 Training started!")
|
@@ -109,28 +120,40 @@ class SmolLM3Trainer:
|
|
109 |
def on_save(self, args, state, control, **kwargs):
|
110 |
step = state.global_step if hasattr(state, 'global_step') else 'unknown'
|
111 |
print(f"💾 Checkpoint saved at step {step}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
|
113 |
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
|
114 |
if metrics and isinstance(metrics, dict):
|
115 |
step = state.global_step if hasattr(state, 'global_step') else 'unknown'
|
116 |
eval_loss = metrics.get('eval_loss', 'N/A')
|
117 |
print(f"📊 Evaluation at step {step}: eval_loss={eval_loss}")
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
# Add console callback
|
120 |
callbacks.append(SimpleConsoleCallback())
|
121 |
logger.info("Added simple console monitoring callback")
|
122 |
|
123 |
-
# Add
|
124 |
-
if self.monitor
|
125 |
try:
|
126 |
trackio_callback = self.monitor.create_monitoring_callback()
|
127 |
if trackio_callback:
|
128 |
callbacks.append(trackio_callback)
|
129 |
-
logger.info("Added
|
130 |
else:
|
131 |
-
logger.warning("Failed to create
|
132 |
except Exception as e:
|
133 |
-
logger.error("Error creating
|
134 |
logger.info("Continuing with console monitoring only")
|
135 |
|
136 |
logger.info("Total callbacks: %d", len(callbacks))
|
@@ -220,16 +243,20 @@ class SmolLM3Trainer:
|
|
220 |
"""Start training"""
|
221 |
logger.info("Starting training")
|
222 |
|
223 |
-
# Log configuration to Trackio
|
224 |
-
if self.monitor
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
|
|
|
|
|
|
|
|
233 |
|
234 |
# Load checkpoint if resuming
|
235 |
if self.init_from == "resume":
|
@@ -251,17 +278,20 @@ class SmolLM3Trainer:
|
|
251 |
with open(os.path.join(self.output_dir, "train_results.json"), "w") as f:
|
252 |
json.dump(train_result.metrics, f, indent=2)
|
253 |
|
254 |
-
# Log training summary to Trackio
|
255 |
-
if self.monitor
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
|
|
|
|
|
|
265 |
|
266 |
# Finish trackio experiment
|
267 |
try:
|
@@ -276,9 +306,12 @@ class SmolLM3Trainer:
|
|
276 |
|
277 |
except Exception as e:
|
278 |
logger.error("Training failed: %s", e)
|
279 |
-
# Close monitoring on error
|
280 |
-
if self.monitor
|
281 |
-
|
|
|
|
|
|
|
282 |
|
283 |
# Finish trackio experiment on error
|
284 |
try:
|
|
|
78 |
# Add simple console callback for basic monitoring
|
79 |
from transformers import TrainerCallback
|
80 |
|
81 |
+
outer = self
|
82 |
class SimpleConsoleCallback(TrainerCallback):
|
83 |
def on_init_end(self, args, state, control, **kwargs):
|
84 |
"""Called when training initialization is complete"""
|
|
|
100 |
else:
|
101 |
lr_str = str(lr)
|
102 |
print(f"Step {step}: loss={loss_str}, lr={lr_str}")
|
103 |
+
|
104 |
+
# Persist metrics via our monitor when Trackio callback isn't active
|
105 |
+
try:
|
106 |
+
if outer.monitor:
|
107 |
+
# Avoid double logging when Trackio callback is used
|
108 |
+
if not outer.monitor.enable_tracking:
|
109 |
+
outer.monitor.log_metrics(dict(logs), step if isinstance(step, int) else None)
|
110 |
+
outer.monitor.log_system_metrics(step if isinstance(step, int) else None)
|
111 |
+
except Exception as e:
|
112 |
+
logger.warning("SimpleConsoleCallback metrics persistence failed: %s", e)
|
113 |
|
114 |
def on_train_begin(self, args, state, control, **kwargs):
|
115 |
print("🚀 Training started!")
|
|
|
120 |
def on_save(self, args, state, control, **kwargs):
|
121 |
step = state.global_step if hasattr(state, 'global_step') else 'unknown'
|
122 |
print(f"💾 Checkpoint saved at step {step}")
|
123 |
+
try:
|
124 |
+
if outer.monitor and not outer.monitor.enable_tracking:
|
125 |
+
checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{step}")
|
126 |
+
if os.path.exists(checkpoint_path):
|
127 |
+
outer.monitor.log_model_checkpoint(checkpoint_path, step if isinstance(step, int) else None)
|
128 |
+
except Exception as e:
|
129 |
+
logger.warning("SimpleConsoleCallback checkpoint persistence failed: %s", e)
|
130 |
|
131 |
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
|
132 |
if metrics and isinstance(metrics, dict):
|
133 |
step = state.global_step if hasattr(state, 'global_step') else 'unknown'
|
134 |
eval_loss = metrics.get('eval_loss', 'N/A')
|
135 |
print(f"📊 Evaluation at step {step}: eval_loss={eval_loss}")
|
136 |
+
try:
|
137 |
+
if outer.monitor and not outer.monitor.enable_tracking:
|
138 |
+
outer.monitor.log_evaluation_results(dict(metrics), step if isinstance(step, int) else None)
|
139 |
+
except Exception as e:
|
140 |
+
logger.warning("SimpleConsoleCallback eval persistence failed: %s", e)
|
141 |
|
142 |
# Add console callback
|
143 |
callbacks.append(SimpleConsoleCallback())
|
144 |
logger.info("Added simple console monitoring callback")
|
145 |
|
146 |
+
# Add monitoring callback if available (always attach; it persists to dataset even if Trackio is disabled)
|
147 |
+
if self.monitor:
|
148 |
try:
|
149 |
trackio_callback = self.monitor.create_monitoring_callback()
|
150 |
if trackio_callback:
|
151 |
callbacks.append(trackio_callback)
|
152 |
+
logger.info("Added monitoring callback")
|
153 |
else:
|
154 |
+
logger.warning("Failed to create monitoring callback")
|
155 |
except Exception as e:
|
156 |
+
logger.error("Error creating monitoring callback: %s", e)
|
157 |
logger.info("Continuing with console monitoring only")
|
158 |
|
159 |
logger.info("Total callbacks: %d", len(callbacks))
|
|
|
243 |
"""Start training"""
|
244 |
logger.info("Starting training")
|
245 |
|
246 |
+
# Log configuration (always persist to dataset; Trackio if enabled)
|
247 |
+
if self.monitor:
|
248 |
+
try:
|
249 |
+
config_dict = {k: v for k, v in self.config.__dict__.items() if not k.startswith('_')}
|
250 |
+
self.monitor.log_config(config_dict)
|
251 |
+
except Exception as e:
|
252 |
+
logger.warning("Failed to log configuration: %s", e)
|
253 |
+
# Log experiment URL only if available
|
254 |
+
try:
|
255 |
+
experiment_url = self.monitor.get_experiment_url()
|
256 |
+
if experiment_url:
|
257 |
+
logger.info("Trackio experiment URL: %s", experiment_url)
|
258 |
+
except Exception:
|
259 |
+
pass
|
260 |
|
261 |
# Load checkpoint if resuming
|
262 |
if self.init_from == "resume":
|
|
|
278 |
with open(os.path.join(self.output_dir, "train_results.json"), "w") as f:
|
279 |
json.dump(train_result.metrics, f, indent=2)
|
280 |
|
281 |
+
# Log training summary (always persist to dataset; Trackio if enabled)
|
282 |
+
if self.monitor:
|
283 |
+
try:
|
284 |
+
summary = {
|
285 |
+
'final_loss': train_result.metrics.get('train_loss', 0),
|
286 |
+
'total_steps': train_result.metrics.get('train_runtime', 0),
|
287 |
+
'training_time': train_result.metrics.get('train_runtime', 0),
|
288 |
+
'output_dir': self.output_dir,
|
289 |
+
'model_name': getattr(self.config, 'model_name', 'unknown'),
|
290 |
+
}
|
291 |
+
self.monitor.log_training_summary(summary)
|
292 |
+
self.monitor.close()
|
293 |
+
except Exception as e:
|
294 |
+
logger.warning("Failed to log training summary: %s", e)
|
295 |
|
296 |
# Finish trackio experiment
|
297 |
try:
|
|
|
306 |
|
307 |
except Exception as e:
|
308 |
logger.error("Training failed: %s", e)
|
309 |
+
# Close monitoring on error (still persist final status to dataset)
|
310 |
+
if self.monitor:
|
311 |
+
try:
|
312 |
+
self.monitor.close(final_status="failed")
|
313 |
+
except Exception:
|
314 |
+
pass
|
315 |
|
316 |
# Finish trackio experiment on error
|
317 |
try:
|