Spaces:
Running
Running
Add experiment ID fix
Browse files
- src/trackio.py +18 -7
- src/trainer.py +16 -16
- templates/spaces/trackio/app.py +23 -3
src/trackio.py
CHANGED
@@ -65,15 +65,12 @@ def init(
|
|
65 |
hf_token=hf_token,
|
66 |
dataset_repo=dataset_repo
|
67 |
)
|
68 |
-
|
69 |
-
#
|
70 |
-
experiment_id =
|
71 |
-
_monitor.experiment_id = experiment_id
|
72 |
-
|
73 |
logger.info(f"Trackio initialized for experiment: {exp_name}")
|
74 |
logger.info(f"Experiment ID: {experiment_id}")
|
75 |
-
|
76 |
-
return experiment_id
|
77 |
|
78 |
except Exception as e:
|
79 |
logger.error(f"Failed to initialize trackio: {e}")
|
@@ -128,6 +125,20 @@ def finish():
|
|
128 |
except Exception as e:
|
129 |
logger.error(f"Failed to finish trackio experiment: {e}")
|
130 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
def log_config(config: Dict[str, Any]):
|
132 |
"""
|
133 |
Log configuration to trackio (TRL interface)
|
|
|
65 |
hf_token=hf_token,
|
66 |
dataset_repo=dataset_repo
|
67 |
)
|
68 |
+
# The monitor constructor creates the experiment remotely and sets
|
69 |
+
# `experiment_id`. Do NOT overwrite it with a locally generated ID.
|
70 |
+
experiment_id = getattr(_monitor, "experiment_id", None)
|
|
|
|
|
71 |
logger.info(f"Trackio initialized for experiment: {exp_name}")
|
72 |
logger.info(f"Experiment ID: {experiment_id}")
|
73 |
+
return experiment_id or f"exp_fallback_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
|
|
74 |
|
75 |
except Exception as e:
|
76 |
logger.error(f"Failed to initialize trackio: {e}")
|
|
|
125 |
except Exception as e:
|
126 |
logger.error(f"Failed to finish trackio experiment: {e}")
|
127 |
|
128 |
+
def set_monitor(monitor: SmolLM3Monitor) -> None:
|
129 |
+
"""Set the shared monitor instance used by this module.
|
130 |
+
|
131 |
+
This allows external code (e.g., our trainer) to create a
|
132 |
+
`SmolLM3Monitor` once and have `trackio.log/finish` operate on
|
133 |
+
the exact same object, preventing mismatched experiment IDs.
|
134 |
+
"""
|
135 |
+
global _monitor
|
136 |
+
_monitor = monitor
|
137 |
+
try:
|
138 |
+
logger.info("trackio monitor set: experiment_id=%s", getattr(monitor, "experiment_id", None))
|
139 |
+
except Exception:
|
140 |
+
pass
|
141 |
+
|
142 |
def log_config(config: Dict[str, Any]):
|
143 |
"""
|
144 |
Log configuration to trackio (TRL interface)
|
src/trainer.py
CHANGED
@@ -158,17 +158,23 @@ class SmolLM3Trainer:
|
|
158 |
|
159 |
logger.info("Total callbacks: %d", len(callbacks))
|
160 |
|
161 |
-
# Initialize trackio for TRL compatibility
|
162 |
try:
|
163 |
import trackio
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
else:
|
170 |
-
#
|
171 |
-
|
172 |
project_name=getattr(self.config, 'experiment_name', 'smollm3_experiment'),
|
173 |
experiment_name=getattr(self.config, 'experiment_name', 'smollm3_experiment'),
|
174 |
trackio_url=getattr(self.config, 'trackio_url', None),
|
@@ -176,15 +182,9 @@ class SmolLM3Trainer:
|
|
176 |
hf_token=getattr(self.config, 'hf_token', None),
|
177 |
dataset_repo=getattr(self.config, 'dataset_repo', None)
|
178 |
)
|
179 |
-
logger.info(f"Trackio initialized with experiment ID: {experiment_id}")
|
180 |
-
|
181 |
-
# Update our monitor with the same experiment ID
|
182 |
-
if self.monitor:
|
183 |
-
self.monitor.experiment_id = experiment_id
|
184 |
-
logger.info(f"Updated monitor with experiment ID: {experiment_id}")
|
185 |
except Exception as e:
|
186 |
-
logger.warning(f"Failed to
|
187 |
-
logger.info("Continuing without trackio integration")
|
188 |
|
189 |
# Try SFTTrainer first (better for instruction tuning)
|
190 |
logger.info("Creating SFTTrainer with training arguments...")
|
|
|
158 |
|
159 |
logger.info("Total callbacks: %d", len(callbacks))
|
160 |
|
161 |
+
# Initialize trackio for TRL compatibility without creating a second experiment
|
162 |
try:
|
163 |
import trackio
|
164 |
+
if self.monitor:
|
165 |
+
# Share the same monitor/experiment with the trackio shim
|
166 |
+
try:
|
167 |
+
trackio.set_monitor(self.monitor) # type: ignore[attr-defined]
|
168 |
+
except Exception:
|
169 |
+
# Best-effort: if the shim cannot accept the shared monitor, ignore the error and continue (no fallback is performed here)
|
170 |
+
pass
|
171 |
+
logger.info(
|
172 |
+
"Using shared Trackio monitor with experiment ID: %s",
|
173 |
+
getattr(self.monitor, 'experiment_id', None)
|
174 |
+
)
|
175 |
else:
|
176 |
+
# Last resort: initialize via shim
|
177 |
+
_ = trackio.init(
|
178 |
project_name=getattr(self.config, 'experiment_name', 'smollm3_experiment'),
|
179 |
experiment_name=getattr(self.config, 'experiment_name', 'smollm3_experiment'),
|
180 |
trackio_url=getattr(self.config, 'trackio_url', None),
|
|
|
182 |
hf_token=getattr(self.config, 'hf_token', None),
|
183 |
dataset_repo=getattr(self.config, 'dataset_repo', None)
|
184 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
185 |
except Exception as e:
|
186 |
+
logger.warning(f"Failed to wire trackio shim: {e}")
|
187 |
+
logger.info("Continuing without trackio shim integration")
|
188 |
|
189 |
# Try SFTTrainer first (better for instruction tuning)
|
190 |
logger.info("Creating SFTTrainer with training arguments...")
|
templates/spaces/trackio/app.py
CHANGED
@@ -1143,12 +1143,25 @@ def create_metrics_plot(experiment_id: str, metric_name: str = "loss") -> go.Fig
|
|
1143 |
)
|
1144 |
return fig
|
1145 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1146 |
fig = px.line(df, x='step', y=metric_name, title=f'{metric_name} over time')
|
1147 |
fig.update_layout(
|
1148 |
xaxis_title="Training Step",
|
1149 |
yaxis_title=metric_name.title(),
|
1150 |
hovermode='x unified'
|
1151 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
1152 |
return fig
|
1153 |
|
1154 |
except Exception as e:
|
@@ -1530,16 +1543,23 @@ def create_combined_metrics_plot(experiment_id: str) -> go.Figure:
|
|
1530 |
col = (i % n_cols) + 1
|
1531 |
color = colors[i % len(colors)]
|
1532 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1533 |
fig.add_trace(
|
1534 |
go.Scatter(
|
1535 |
-
x=
|
1536 |
-
y=
|
1537 |
mode='lines+markers',
|
1538 |
name=metric,
|
1539 |
line=dict(width=2, color=color),
|
1540 |
marker=dict(size=4, color=color),
|
1541 |
showlegend=False,
|
1542 |
-
connectgaps=
|
1543 |
),
|
1544 |
row=row, col=col
|
1545 |
)
|
|
|
1143 |
)
|
1144 |
return fig
|
1145 |
|
1146 |
+
# Ensure steps are numeric and monotonically increasing to avoid zig-zag lines
|
1147 |
+
try:
|
1148 |
+
df = df.copy()
|
1149 |
+
df['step'] = pd.to_numeric(df['step'], errors='coerce').fillna(-1)
|
1150 |
+
df.sort_values('step', inplace=True)
|
1151 |
+
except Exception:
|
1152 |
+
pass
|
1153 |
fig = px.line(df, x='step', y=metric_name, title=f'{metric_name} over time')
|
1154 |
fig.update_layout(
|
1155 |
xaxis_title="Training Step",
|
1156 |
yaxis_title=metric_name.title(),
|
1157 |
hovermode='x unified'
|
1158 |
)
|
1159 |
+
# Avoid interpolating across missing steps which can create odd visuals
|
1160 |
+
try:
|
1161 |
+
for trace in fig.data:
|
1162 |
+
trace.connectgaps = False
|
1163 |
+
except Exception:
|
1164 |
+
pass
|
1165 |
return fig
|
1166 |
|
1167 |
except Exception as e:
|
|
|
1543 |
col = (i % n_cols) + 1
|
1544 |
color = colors[i % len(colors)]
|
1545 |
|
1546 |
+
# Clean steps for each subplot too
|
1547 |
+
try:
|
1548 |
+
df_sub = df.copy()
|
1549 |
+
df_sub['step'] = pd.to_numeric(df_sub['step'], errors='coerce').fillna(-1)
|
1550 |
+
df_sub.sort_values('step', inplace=True)
|
1551 |
+
except Exception:
|
1552 |
+
df_sub = df
|
1553 |
fig.add_trace(
|
1554 |
go.Scatter(
|
1555 |
+
x=df_sub['step'].tolist(),
|
1556 |
+
y=df_sub[metric].tolist(),
|
1557 |
mode='lines+markers',
|
1558 |
name=metric,
|
1559 |
line=dict(width=2, color=color),
|
1560 |
marker=dict(size=4, color=color),
|
1561 |
showlegend=False,
|
1562 |
+
connectgaps=False
|
1563 |
),
|
1564 |
row=row, col=col
|
1565 |
)
|