Tonic committed on
Commit
8606a9a
·
verified ·
1 Parent(s): 5fe45a6

removes callbacks

Browse files
Files changed (2) hide show
  1. monitoring.py +4 -0
  2. trainer.py +43 -4
monitoring.py CHANGED
@@ -309,6 +309,10 @@ class SmolLM3Monitor:
309
  self.monitor.close()
310
  except Exception as e:
311
  logger.error(f"Error in on_train_end: {e}")
 
 
 
 
312
 
313
  return TrackioCallback(self)
314
 
 
309
  self.monitor.close()
310
  except Exception as e:
311
  logger.error(f"Error in on_train_end: {e}")
312
+
313
+ def __call__(self, *args, **kwargs):
314
+ """Make the callback callable to avoid any issues"""
315
+ return self
316
 
317
  return TrackioCallback(self)
318
 
trainer.py CHANGED
@@ -61,12 +61,51 @@ class SmolLM3Trainer:
61
  # Get data collator
62
  data_collator = self.dataset.get_data_collator()
63
 
64
- # Add monitoring callback
65
  callbacks = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  if self.monitor and self.monitor.enable_tracking:
67
- trackio_callback = self.monitor.create_monitoring_callback()
68
- if trackio_callback:
69
- callbacks.append(trackio_callback)
 
 
 
 
 
 
 
70
 
71
  if self.use_sft_trainer:
72
  # Use SFTTrainer for supervised fine-tuning
 
61
  # Get data collator
62
  data_collator = self.dataset.get_data_collator()
63
 
64
+ # Add monitoring callback - temporarily disabled to debug
65
  callbacks = []
66
+
67
+ # Simple console callback for basic monitoring
68
class SimpleConsoleCallback:
    """Print basic training progress to the console.

    Each hook receives the trainer's ``args``, ``state`` and ``control``
    objects plus event-specific keyword arguments, and writes a single
    human-readable line to stdout. None of the hooks return a value.

    NOTE(review): this class defines only the five hooks below and has no
    base class. If the trainer it is registered with dispatches other events
    (e.g. ``on_step_end``), those lookups would fail -- confirm whether it
    should subclass ``transformers.TrainerCallback``.
    """

    @staticmethod
    def _current_step(state):
        """Return ``state.global_step``, or ``'unknown'`` if it is absent."""
        return getattr(state, 'global_step', 'unknown')

    @staticmethod
    def _format_loss(loss):
        """Render *loss* with four decimals when it is numeric.

        Log dicts do not always carry a ``loss`` entry, in which case the
        caller passes the ``'N/A'`` placeholder. Applying ``.4f`` to that
        string raises ``ValueError``, so non-numeric values fall back to
        ``str()``.
        """
        try:
            return format(loss, ".4f")
        except (TypeError, ValueError):
            return str(loss)

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Log metrics to console"""
        if logs and isinstance(logs, dict):
            step = self._current_step(state)
            loss = self._format_loss(logs.get('loss', 'N/A'))
            lr = logs.get('learning_rate', 'N/A')
            print(f"Step {step}: loss={loss}, lr={lr}")

    def on_train_begin(self, args, state, control, **kwargs):
        """Announce that training has started."""
        print("🚀 Training started!")

    def on_train_end(self, args, state, control, **kwargs):
        """Announce that training has finished."""
        print("✅ Training completed!")

    def on_save(self, args, state, control, **kwargs):
        """Report the step at which a checkpoint was saved."""
        step = self._current_step(state)
        print(f"💾 Checkpoint saved at step {step}")

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Report ``eval_loss`` from a non-empty metrics dict."""
        if metrics and isinstance(metrics, dict):
            step = self._current_step(state)
            eval_loss = metrics.get('eval_loss', 'N/A')
            print(f"📊 Evaluation at step {step}: eval_loss={eval_loss}")
92
+
93
+ # Add simple console callback
94
+ callbacks.append(SimpleConsoleCallback())
95
+ logger.info("Added simple console monitoring callback")
96
+
97
+ # Try to add Trackio callback if available
98
  if self.monitor and self.monitor.enable_tracking:
99
+ try:
100
+ trackio_callback = self.monitor.create_monitoring_callback()
101
+ if trackio_callback:
102
+ callbacks.append(trackio_callback)
103
+ logger.info("Added Trackio monitoring callback")
104
+ else:
105
+ logger.warning("Failed to create Trackio callback")
106
+ except Exception as e:
107
+ logger.error(f"Error creating Trackio callback: {e}")
108
+ logger.info("Continuing with console monitoring only")
109
 
110
  if self.use_sft_trainer:
111
  # Use SFTTrainer for supervised fine-tuning