Tonic committed on
Commit
c560f4f
·
1 Parent(s): 0f12d91

adds experiment id fix

Browse files
Files changed (3) hide show
  1. src/trackio.py +18 -7
  2. src/trainer.py +16 -16
  3. templates/spaces/trackio/app.py +23 -3
src/trackio.py CHANGED
@@ -65,15 +65,12 @@ def init(
65
  hf_token=hf_token,
66
  dataset_repo=dataset_repo
67
  )
68
-
69
- # Generate experiment ID - use the same format as our monitoring system
70
- experiment_id = f"exp_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
71
- _monitor.experiment_id = experiment_id
72
-
73
  logger.info(f"Trackio initialized for experiment: {exp_name}")
74
  logger.info(f"Experiment ID: {experiment_id}")
75
-
76
- return experiment_id
77
 
78
  except Exception as e:
79
  logger.error(f"Failed to initialize trackio: {e}")
@@ -128,6 +125,20 @@ def finish():
128
  except Exception as e:
129
  logger.error(f"Failed to finish trackio experiment: {e}")
130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  def log_config(config: Dict[str, Any]):
132
  """
133
  Log configuration to trackio (TRL interface)
 
65
  hf_token=hf_token,
66
  dataset_repo=dataset_repo
67
  )
68
+ # The monitor constructor creates the experiment remotely and sets
69
+ # `experiment_id`. Do NOT overwrite it with a locally generated ID.
70
+ experiment_id = getattr(_monitor, "experiment_id", None)
 
 
71
  logger.info(f"Trackio initialized for experiment: {exp_name}")
72
  logger.info(f"Experiment ID: {experiment_id}")
73
+ return experiment_id or f"exp_fallback_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
 
74
 
75
  except Exception as e:
76
  logger.error(f"Failed to initialize trackio: {e}")
 
125
  except Exception as e:
126
  logger.error(f"Failed to finish trackio experiment: {e}")
127
 
128
+ def set_monitor(monitor: SmolLM3Monitor) -> None:
129
+ """Set the shared monitor instance used by this module.
130
+
131
+ This allows external code (e.g., our trainer) to create a
132
+ `SmolLM3Monitor` once and have `trackio.log/finish` operate on
133
+ the exact same object, preventing mismatched experiment IDs.
134
+ """
135
+ global _monitor
136
+ _monitor = monitor
137
+ try:
138
+ logger.info("trackio monitor set: experiment_id=%s", getattr(monitor, "experiment_id", None))
139
+ except Exception:
140
+ pass
141
+
142
  def log_config(config: Dict[str, Any]):
143
  """
144
  Log configuration to trackio (TRL interface)
src/trainer.py CHANGED
@@ -158,17 +158,23 @@ class SmolLM3Trainer:
158
 
159
  logger.info("Total callbacks: %d", len(callbacks))
160
 
161
- # Initialize trackio for TRL compatibility
162
  try:
163
  import trackio
164
- # Initialize trackio with our configuration and use the same experiment ID
165
- if self.monitor and self.monitor.experiment_id:
166
- # Use the experiment ID from our monitor
167
- experiment_id = self.monitor.experiment_id
168
- logger.info(f"Using existing experiment ID: {experiment_id}")
 
 
 
 
 
 
169
  else:
170
- # Initialize trackio with our configuration
171
- experiment_id = trackio.init(
172
  project_name=getattr(self.config, 'experiment_name', 'smollm3_experiment'),
173
  experiment_name=getattr(self.config, 'experiment_name', 'smollm3_experiment'),
174
  trackio_url=getattr(self.config, 'trackio_url', None),
@@ -176,15 +182,9 @@ class SmolLM3Trainer:
176
  hf_token=getattr(self.config, 'hf_token', None),
177
  dataset_repo=getattr(self.config, 'dataset_repo', None)
178
  )
179
- logger.info(f"Trackio initialized with experiment ID: {experiment_id}")
180
-
181
- # Update our monitor with the same experiment ID
182
- if self.monitor:
183
- self.monitor.experiment_id = experiment_id
184
- logger.info(f"Updated monitor with experiment ID: {experiment_id}")
185
  except Exception as e:
186
- logger.warning(f"Failed to initialize trackio: {e}")
187
- logger.info("Continuing without trackio integration")
188
 
189
  # Try SFTTrainer first (better for instruction tuning)
190
  logger.info("Creating SFTTrainer with training arguments...")
 
158
 
159
  logger.info("Total callbacks: %d", len(callbacks))
160
 
161
+ # Initialize trackio for TRL compatibility without creating a second experiment
162
  try:
163
  import trackio
164
+ if self.monitor:
165
+ # Share the same monitor/experiment with the trackio shim
166
+ try:
167
+ trackio.set_monitor(self.monitor) # type: ignore[attr-defined]
168
+ except Exception:
169
+ # Fallback: ensure the shim at least knows the current ID
170
+ pass
171
+ logger.info(
172
+ "Using shared Trackio monitor with experiment ID: %s",
173
+ getattr(self.monitor, 'experiment_id', None)
174
+ )
175
  else:
176
+ # Last resort: initialize via shim
177
+ _ = trackio.init(
178
  project_name=getattr(self.config, 'experiment_name', 'smollm3_experiment'),
179
  experiment_name=getattr(self.config, 'experiment_name', 'smollm3_experiment'),
180
  trackio_url=getattr(self.config, 'trackio_url', None),
 
182
  hf_token=getattr(self.config, 'hf_token', None),
183
  dataset_repo=getattr(self.config, 'dataset_repo', None)
184
  )
 
 
 
 
 
 
185
  except Exception as e:
186
+ logger.warning(f"Failed to wire trackio shim: {e}")
187
+ logger.info("Continuing without trackio shim integration")
188
 
189
  # Try SFTTrainer first (better for instruction tuning)
190
  logger.info("Creating SFTTrainer with training arguments...")
templates/spaces/trackio/app.py CHANGED
@@ -1143,12 +1143,25 @@ def create_metrics_plot(experiment_id: str, metric_name: str = "loss") -> go.Fig
1143
  )
1144
  return fig
1145
 
 
 
 
 
 
 
 
1146
  fig = px.line(df, x='step', y=metric_name, title=f'{metric_name} over time')
1147
  fig.update_layout(
1148
  xaxis_title="Training Step",
1149
  yaxis_title=metric_name.title(),
1150
  hovermode='x unified'
1151
  )
 
 
 
 
 
 
1152
  return fig
1153
 
1154
  except Exception as e:
@@ -1530,16 +1543,23 @@ def create_combined_metrics_plot(experiment_id: str) -> go.Figure:
1530
  col = (i % n_cols) + 1
1531
  color = colors[i % len(colors)]
1532
 
 
 
 
 
 
 
 
1533
  fig.add_trace(
1534
  go.Scatter(
1535
- x=df['step'].tolist(),
1536
- y=df[metric].tolist(),
1537
  mode='lines+markers',
1538
  name=metric,
1539
  line=dict(width=2, color=color),
1540
  marker=dict(size=4, color=color),
1541
  showlegend=False,
1542
- connectgaps=True
1543
  ),
1544
  row=row, col=col
1545
  )
 
1143
  )
1144
  return fig
1145
 
1146
+ # Ensure steps are numeric and monotonically increasing to avoid zig-zag lines
1147
+ try:
1148
+ df = df.copy()
1149
+ df['step'] = pd.to_numeric(df['step'], errors='coerce').fillna(-1)
1150
+ df.sort_values('step', inplace=True)
1151
+ except Exception:
1152
+ pass
1153
  fig = px.line(df, x='step', y=metric_name, title=f'{metric_name} over time')
1154
  fig.update_layout(
1155
  xaxis_title="Training Step",
1156
  yaxis_title=metric_name.title(),
1157
  hovermode='x unified'
1158
  )
1159
+ # Avoid interpolating across missing steps which can create odd visuals
1160
+ try:
1161
+ for trace in fig.data:
1162
+ trace.connectgaps = False
1163
+ except Exception:
1164
+ pass
1165
  return fig
1166
 
1167
  except Exception as e:
 
1543
  col = (i % n_cols) + 1
1544
  color = colors[i % len(colors)]
1545
 
1546
+ # Clean steps for each subplot too
1547
+ try:
1548
+ df_sub = df.copy()
1549
+ df_sub['step'] = pd.to_numeric(df_sub['step'], errors='coerce').fillna(-1)
1550
+ df_sub.sort_values('step', inplace=True)
1551
+ except Exception:
1552
+ df_sub = df
1553
  fig.add_trace(
1554
  go.Scatter(
1555
+ x=df_sub['step'].tolist(),
1556
+ y=df_sub[metric].tolist(),
1557
  mode='lines+markers',
1558
  name=metric,
1559
  line=dict(width=2, color=color),
1560
  marker=dict(size=4, color=color),
1561
  showlegend=False,
1562
+ connectgaps=False
1563
  ),
1564
  row=row, col=col
1565
  )