Spaces:
Running
Running
removes callbacks
Browse files- monitoring.py +4 -0
- trainer.py +43 -4
monitoring.py
CHANGED
@@ -309,6 +309,10 @@ class SmolLM3Monitor:
|
|
309 |
self.monitor.close()
|
310 |
except Exception as e:
|
311 |
logger.error(f"Error in on_train_end: {e}")
|
|
|
|
|
|
|
|
|
312 |
|
313 |
return TrackioCallback(self)
|
314 |
|
|
|
309 |
self.monitor.close()
|
310 |
except Exception as e:
|
311 |
logger.error(f"Error in on_train_end: {e}")
|
312 |
+
|
313 |
+
def __call__(self, *args, **kwargs):
    """Return this same callback object when it is invoked.

    All positional and keyword arguments are accepted and ignored, so
    calling the instance hands back the instance itself.
    """
    return self
|
316 |
|
317 |
return TrackioCallback(self)
|
318 |
|
trainer.py
CHANGED
@@ -61,12 +61,51 @@ class SmolLM3Trainer:
|
|
61 |
# Get data collator
|
62 |
data_collator = self.dataset.get_data_collator()
|
63 |
|
64 |
-
# Add monitoring callback
|
65 |
callbacks = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
if self.monitor and self.monitor.enable_tracking:
|
67 |
-
|
68 |
-
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
if self.use_sft_trainer:
|
72 |
# Use SFTTrainer for supervised fine-tuning
|
|
|
61 |
# Get data collator
|
62 |
data_collator = self.dataset.get_data_collator()
|
63 |
|
64 |
+
# Build the callback list: console logging always, Trackio when tracking is enabled
|
65 |
callbacks = []
|
66 |
+
|
67 |
+
# Simple console callback for basic monitoring
|
68 |
+
class SimpleConsoleCallback:
    """Print basic training progress to the console.

    Implements the ``on_log``, ``on_train_begin``, ``on_train_end``,
    ``on_save`` and ``on_evaluate`` hooks. Any other ``on_*`` hook looked up
    on an instance resolves to a no-op, so a dispatcher that invokes every
    event by name does not fail on the hooks this class leaves out.

    NOTE(review): this class has no base class. If the trainer's callback
    base class (presumably ``transformers.TrainerCallback``) is importable
    here, inheriting from it is the conventional way to get default no-op
    hooks -- confirm and switch if so.
    """

    def __getattr__(self, name):
        """Resolve unimplemented ``on_*`` hooks to a no-op.

        Only consulted when normal attribute lookup fails, so the hooks
        defined below are unaffected.

        Raises:
            AttributeError: for any missing name that is not an ``on_*`` hook.
        """
        if name.startswith("on_"):
            return self._ignore_event
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {name!r}"
        )

    @staticmethod
    def _ignore_event(*args, **kwargs):
        """Accept any event arguments and do nothing."""
        return None

    @staticmethod
    def _current_step(state):
        """Return ``state.global_step``, or ``'unknown'`` if it is absent."""
        return getattr(state, 'global_step', 'unknown')

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Log metrics to console.

        Prints the step, loss and learning rate when ``logs`` is a non-empty
        dict; otherwise prints nothing. Missing values are shown as ``N/A``.
        """
        if not (logs and isinstance(logs, dict)):
            return
        step = self._current_step(state)
        loss = logs.get('loss', 'N/A')
        lr = logs.get('learning_rate', 'N/A')
        # Log dicts without a 'loss' entry yield the 'N/A' placeholder, which
        # cannot take a float format spec; fall back to its plain text form
        # instead of raising from inside the logging hook.
        try:
            loss_text = f"{loss:.4f}"
        except (TypeError, ValueError):
            loss_text = str(loss)
        print(f"Step {step}: loss={loss_text}, lr={lr}")

    def on_train_begin(self, args, state, control, **kwargs):
        """Announce that training has started."""
        print("🚀 Training started!")

    def on_train_end(self, args, state, control, **kwargs):
        """Announce that training has finished."""
        print("✅ Training completed!")

    def on_save(self, args, state, control, **kwargs):
        """Report the step at which a checkpoint was saved."""
        step = self._current_step(state)
        print(f"💾 Checkpoint saved at step {step}")

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Report ``eval_loss`` when ``metrics`` is a non-empty dict."""
        if not (metrics and isinstance(metrics, dict)):
            return
        step = self._current_step(state)
        eval_loss = metrics.get('eval_loss', 'N/A')
        print(f"📊 Evaluation at step {step}: eval_loss={eval_loss}")
|
92 |
+
|
93 |
+
# Add simple console callback
|
94 |
+
callbacks.append(SimpleConsoleCallback())
|
95 |
+
logger.info("Added simple console monitoring callback")
|
96 |
+
|
97 |
+
# Try to add Trackio callback if available
|
98 |
if self.monitor and self.monitor.enable_tracking:
|
99 |
+
try:
|
100 |
+
trackio_callback = self.monitor.create_monitoring_callback()
|
101 |
+
if trackio_callback:
|
102 |
+
callbacks.append(trackio_callback)
|
103 |
+
logger.info("Added Trackio monitoring callback")
|
104 |
+
else:
|
105 |
+
logger.warning("Failed to create Trackio callback")
|
106 |
+
except Exception as e:
|
107 |
+
logger.error(f"Error creating Trackio callback: {e}")
|
108 |
+
logger.info("Continuing with console monitoring only")
|
109 |
|
110 |
if self.use_sft_trainer:
|
111 |
# Use SFTTrainer for supervised fine-tuning
|