Tonic committed on
Commit
924581c
·
1 Parent(s): 81f39f1

improve dataset utils: non-destructive writes

Browse files
scripts/trackio_tonic/dataset_utils.py CHANGED
@@ -90,82 +90,180 @@ class TrackioDatasetManager:
90
 
91
  def _validate_experiment_structure(self, experiment: Dict[str, Any]) -> bool:
92
  """
93
- Validate that an experiment has the required structure.
94
-
95
- Args:
96
- experiment (Dict[str, Any]): Experiment dictionary to validate
97
-
98
- Returns:
99
- bool: True if experiment structure is valid
100
  """
101
- required_fields = [
102
- 'experiment_id', 'name', 'description', 'created_at',
103
- 'status', 'metrics', 'parameters', 'artifacts', 'logs'
104
- ]
105
-
106
- for field in required_fields:
107
- if field not in experiment:
108
- logger.warning(f"⚠️ Missing required field '{field}' in experiment")
109
- return False
110
-
111
- # Validate JSON fields
112
- json_fields = ['metrics', 'parameters', 'artifacts', 'logs']
113
- for field in json_fields:
114
- if isinstance(experiment[field], str):
115
- try:
116
- json.loads(experiment[field])
117
- except json.JSONDecodeError:
118
- logger.warning(f"⚠️ Invalid JSON in field '{field}' for experiment {experiment.get('experiment_id')}")
119
- return False
120
-
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  return True
122
 
123
  def save_experiments(self, experiments: List[Dict[str, Any]], commit_message: Optional[str] = None) -> bool:
124
  """
125
- Save a list of experiments to the dataset, preserving data integrity.
126
-
127
- Args:
128
- experiments (List[Dict[str, Any]]): List of experiment dictionaries
129
- commit_message (Optional[str]): Custom commit message
130
-
131
- Returns:
132
- bool: True if save was successful, False otherwise
133
  """
134
  try:
135
  if not experiments:
136
  logger.warning("⚠️ No experiments to save")
137
  return False
138
-
139
- # Validate all experiments before saving
140
- valid_experiments = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  for exp in experiments:
142
- if self._validate_experiment_structure(exp):
143
- # Ensure last_updated is set
144
- if 'last_updated' not in exp:
145
- exp['last_updated'] = datetime.now().isoformat()
146
- valid_experiments.append(exp)
147
- else:
148
  logger.error(f"❌ Invalid experiment structure: {exp.get('experiment_id', 'unknown')}")
149
  return False
150
-
151
- # Create dataset
152
- dataset = Dataset.from_list(valid_experiments)
153
-
154
- # Generate commit message if not provided
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  if not commit_message:
156
- commit_message = f"Update dataset with {len(valid_experiments)} experiments ({datetime.now().isoformat()})"
157
-
158
- # Push to hub
159
  dataset.push_to_hub(
160
  self.dataset_repo,
161
  token=self.hf_token,
162
  private=True,
163
  commit_message=commit_message
164
  )
165
-
166
- logger.info(f"✅ Successfully saved {len(valid_experiments)} experiments to {self.dataset_repo}")
167
  return True
168
-
169
  except Exception as e:
170
  logger.error(f"❌ Failed to save experiments to dataset: {e}")
171
  return False
 
90
 
91
  def _validate_experiment_structure(self, experiment: Dict[str, Any]) -> bool:
92
  """
93
+ Validate and SANITIZE an experiment structure to prevent destructive failures.
94
+
95
+ - Requires 'experiment_id'; otherwise skip the row.
96
+ - Fills defaults for missing non-JSON fields.
97
+ - Normalizes JSON fields to valid JSON strings.
 
 
98
  """
99
+ if not experiment.get('experiment_id'):
100
+ logger.warning("⚠️ Missing required field 'experiment_id' in experiment; skipping row")
101
+ return False
102
+
103
+ defaults = {
104
+ 'name': '',
105
+ 'description': '',
106
+ 'created_at': datetime.now().isoformat(),
107
+ 'status': 'running',
108
+ }
109
+ for key, default_value in defaults.items():
110
+ if experiment.get(key) in (None, ''):
111
+ experiment[key] = default_value
112
+
113
+ def _ensure_json_string(field_name: str, default_value: Any):
114
+ raw_value = experiment.get(field_name)
115
+ try:
116
+ if isinstance(raw_value, str):
117
+ if raw_value.strip() == '':
118
+ experiment[field_name] = json.dumps(default_value, default=str)
119
+ else:
120
+ json.loads(raw_value)
121
+ else:
122
+ experiment[field_name] = json.dumps(
123
+ raw_value if raw_value is not None else default_value,
124
+ default=str
125
+ )
126
+ except Exception:
127
+ experiment[field_name] = json.dumps(default_value, default=str)
128
+
129
+ for json_field, default in (('metrics', []), ('parameters', {}), ('artifacts', []), ('logs', [])):
130
+ _ensure_json_string(json_field, default)
131
+
132
  return True
133
 
134
  def save_experiments(self, experiments: List[Dict[str, Any]], commit_message: Optional[str] = None) -> bool:
135
  """
136
+ Save experiments using a non-destructive UNION-MERGE by experiment_id.
137
+
138
+ - Loads existing experiments and merges JSON fields non-destructively
139
+ - Incoming scalar fields override existing scalars
140
+ - JSON fields are merged with de-duplication
 
 
 
141
  """
142
  try:
143
  if not experiments:
144
  logger.warning("⚠️ No experiments to save")
145
  return False
146
+
147
+ # Helpers
148
+ def _parse_json_field(value, default):
149
+ try:
150
+ if value is None:
151
+ return default
152
+ if isinstance(value, str):
153
+ return json.loads(value) if value else default
154
+ return value
155
+ except Exception:
156
+ return default
157
+
158
+ def _metrics_key(entry: Dict[str, Any]):
159
+ if isinstance(entry, dict):
160
+ return (entry.get('step'), entry.get('timestamp'))
161
+ return (None, json.dumps(entry, sort_keys=True))
162
+
163
+ # Load existing experiments for union merge
164
+ existing = {}
165
+ try:
166
+ for row in self.load_existing_experiments():
167
+ exp_id = row.get('experiment_id')
168
+ if exp_id:
169
+ existing[exp_id] = row
170
+ except Exception:
171
+ existing = {}
172
+
173
+ merged_map: Dict[str, Dict[str, Any]] = {exp_id: row for exp_id, row in existing.items()}
174
+
175
+ # Validate and merge incoming experiments
176
  for exp in experiments:
177
+ if not self._validate_experiment_structure(exp):
 
 
 
 
 
178
  logger.error(f"❌ Invalid experiment structure: {exp.get('experiment_id', 'unknown')}")
179
  return False
180
+ exp_id = exp['experiment_id']
181
+ incoming = exp
182
+ if exp_id not in merged_map:
183
+ incoming['last_updated'] = incoming.get('last_updated') or datetime.now().isoformat()
184
+ merged_map[exp_id] = incoming
185
+ continue
186
+
187
+ # Merge with existing
188
+ base = merged_map[exp_id]
189
+ base_metrics = _parse_json_field(base.get('metrics'), [])
190
+ base_params = _parse_json_field(base.get('parameters'), {})
191
+ base_artifacts = _parse_json_field(base.get('artifacts'), [])
192
+ base_logs = _parse_json_field(base.get('logs'), [])
193
+ inc_metrics = _parse_json_field(incoming.get('metrics'), [])
194
+ inc_params = _parse_json_field(incoming.get('parameters'), {})
195
+ inc_artifacts = _parse_json_field(incoming.get('artifacts'), [])
196
+ inc_logs = _parse_json_field(incoming.get('logs'), [])
197
+
198
+ # Merge metrics with de-dup
199
+ merged_metrics = []
200
+ seen = set()
201
+ for entry in list(base_metrics) + list(inc_metrics):
202
+ try:
203
+ key = _metrics_key(entry)
204
+ except Exception:
205
+ key = (None, None)
206
+ if key not in seen:
207
+ seen.add(key)
208
+ merged_metrics.append(entry)
209
+
210
+ # Merge params (incoming overrides)
211
+ merged_params = {}
212
+ if isinstance(base_params, dict):
213
+ merged_params.update(base_params)
214
+ if isinstance(inc_params, dict):
215
+ merged_params.update(inc_params)
216
+
217
+ # Merge artifacts/logs with de-dup while preserving order
218
+ def _dedup_list(lst):
219
+ out = []
220
+ seen_local = set()
221
+ for item in lst:
222
+ key = json.dumps(item, sort_keys=True, default=str) if not isinstance(item, str) else item
223
+ if key not in seen_local:
224
+ seen_local.add(key)
225
+ out.append(item)
226
+ return out
227
+
228
+ merged_artifacts = _dedup_list(list(base_artifacts) + list(inc_artifacts))
229
+ merged_logs = _dedup_list(list(base_logs) + list(inc_logs))
230
+
231
+ # Rebuild merged record preferring incoming scalars
232
+ merged_rec = dict(base)
233
+ merged_rec.update({k: v for k, v in incoming.items() if k not in ('metrics', 'parameters', 'artifacts', 'logs')})
234
+ merged_rec['metrics'] = json.dumps(merged_metrics, default=str)
235
+ merged_rec['parameters'] = json.dumps(merged_params, default=str)
236
+ merged_rec['artifacts'] = json.dumps(merged_artifacts, default=str)
237
+ merged_rec['logs'] = json.dumps(merged_logs, default=str)
238
+ merged_rec['last_updated'] = datetime.now().isoformat()
239
+ merged_map[exp_id] = merged_rec
240
+
241
+ # Normalize final list
242
+ normalized = []
243
+ for rec in merged_map.values():
244
+ for f, default in (('metrics', []), ('parameters', {}), ('artifacts', []), ('logs', [])):
245
+ val = rec.get(f)
246
+ if not isinstance(val, str):
247
+ rec[f] = json.dumps(val if val is not None else default, default=str)
248
+ if 'last_updated' not in rec:
249
+ rec['last_updated'] = datetime.now().isoformat()
250
+ normalized.append(rec)
251
+
252
+ dataset = Dataset.from_list(normalized)
253
+
254
  if not commit_message:
255
+ commit_message = f"Union-merge update with {len(normalized)} experiments ({datetime.now().isoformat()})"
256
+
 
257
  dataset.push_to_hub(
258
  self.dataset_repo,
259
  token=self.hf_token,
260
  private=True,
261
  commit_message=commit_message
262
  )
263
+
264
+ logger.info(f"✅ Successfully saved {len(normalized)} experiments (union-merged) to {self.dataset_repo}")
265
  return True
266
+
267
  except Exception as e:
268
  logger.error(f"❌ Failed to save experiments to dataset: {e}")
269
  return False
src/dataset_utils.py CHANGED
@@ -90,34 +90,62 @@ class TrackioDatasetManager:
90
 
91
  def _validate_experiment_structure(self, experiment: Dict[str, Any]) -> bool:
92
  """
93
- Validate that an experiment has the required structure.
94
-
 
 
 
 
 
 
 
 
95
  Args:
96
- experiment (Dict[str, Any]): Experiment dictionary to validate
97
-
98
  Returns:
99
- bool: True if experiment structure is valid
100
  """
101
- required_fields = [
102
- 'experiment_id', 'name', 'description', 'created_at',
103
- 'status', 'metrics', 'parameters', 'artifacts', 'logs'
104
- ]
105
-
106
- for field in required_fields:
107
- if field not in experiment:
108
- logger.warning(f"⚠️ Missing required field '{field}' in experiment")
109
- return False
110
-
111
- # Validate JSON fields
112
- json_fields = ['metrics', 'parameters', 'artifacts', 'logs']
113
- for field in json_fields:
114
- if isinstance(experiment[field], str):
115
- try:
116
- json.loads(experiment[field])
117
- except json.JSONDecodeError:
118
- logger.warning(f"⚠️ Invalid JSON in field '{field}' for experiment {experiment.get('experiment_id')}")
119
- return False
120
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  return True
122
 
123
  def save_experiments(self, experiments: List[Dict[str, Any]], commit_message: Optional[str] = None) -> bool:
 
90
 
91
  def _validate_experiment_structure(self, experiment: Dict[str, Any]) -> bool:
92
  """
93
+ Validate and SANITIZE an experiment structure.
94
+
95
+ This function is intentionally lenient to avoid dropping any
96
+ existing rows from the remote dataset during union-merge saves.
97
+
98
+ Rules:
99
+ - 'experiment_id' must exist; otherwise the row is skipped
100
+ - All other required fields are auto-filled with safe defaults
101
+ - JSON fields are normalized to valid JSON strings
102
+
103
  Args:
104
+ experiment (Dict[str, Any]): Experiment dictionary to validate/sanitize
105
+
106
  Returns:
107
+ bool: True if experiment has (or was sanitized to) a valid structure.
108
  """
109
+ # Hard requirement: experiment_id must be present
110
+ if not experiment.get('experiment_id'):
111
+ logger.warning("⚠️ Missing required field 'experiment_id' in experiment; skipping row")
112
+ return False
113
+
114
+ # Fill defaults for non-JSON scalar fields
115
+ defaults = {
116
+ 'name': '',
117
+ 'description': '',
118
+ 'created_at': datetime.now().isoformat(),
119
+ 'status': 'running',
120
+ }
121
+ for key, default_value in defaults.items():
122
+ if experiment.get(key) in (None, ''):
123
+ experiment[key] = default_value
124
+
125
+ # Normalize JSON fields to valid JSON strings
126
+ def _ensure_json_string(field_name: str, default_value: Any):
127
+ raw_value = experiment.get(field_name)
128
+ try:
129
+ if isinstance(raw_value, str):
130
+ # Validate JSON string; if empty use default
131
+ if raw_value.strip() == '':
132
+ experiment[field_name] = json.dumps(default_value, default=str)
133
+ else:
134
+ json.loads(raw_value)
135
+ # keep as-is if it's valid JSON
136
+ else:
137
+ # Convert object to JSON string
138
+ experiment[field_name] = json.dumps(
139
+ raw_value if raw_value is not None else default_value,
140
+ default=str
141
+ )
142
+ except Exception:
143
+ # On any error, fall back to default JSON
144
+ experiment[field_name] = json.dumps(default_value, default=str)
145
+
146
+ for json_field, default in (('metrics', []), ('parameters', {}), ('artifacts', []), ('logs', [])):
147
+ _ensure_json_string(json_field, default)
148
+
149
  return True
150
 
151
  def save_experiments(self, experiments: List[Dict[str, Any]], commit_message: Optional[str] = None) -> bool:
src/monitoring.py CHANGED
@@ -206,12 +206,12 @@ class SmolLM3Monitor:
206
  def _save_to_hf_dataset(self, experiment_data: Dict[str, Any]):
207
  """Save experiment data to HF Dataset with data preservation using dataset manager.
208
 
209
- This method MERGES with any existing experiment entry to avoid overwriting data:
210
- - If experiment_data contains a 'metrics' list, append new metric entries (with de-dup by step+timestamp)
211
- and store using the nested structure expected by the Trackio Space (each entry has
212
- {timestamp, step, metrics: {...}}).
213
- - Otherwise, treat experiment_data as a parameters update and dict-merge it into existing parameters.
214
- - Artifacts are merged and de-duplicated by their string value.
215
  """
216
  if not self.dataset_manager:
217
  logger.warning("⚠️ Dataset manager not available")
@@ -287,10 +287,21 @@ class SmolLM3Monitor:
287
  merged_metrics.append(nested)
288
  # else: ignore invalid metrics payload
289
  else:
290
- # Treat as parameters update; merge dict
291
  try:
292
  if isinstance(experiment_data, dict):
293
- merged_parameters.update(experiment_data)
 
 
 
 
 
 
 
 
 
 
 
294
  except Exception:
295
  pass
296
 
 
206
  def _save_to_hf_dataset(self, experiment_data: Dict[str, Any]):
207
  """Save experiment data to HF Dataset with data preservation using dataset manager.
208
 
209
+ Non-destructive rules:
210
+ - Merge with existing experiment by experiment_id
211
+ - Metrics: append with de-dup (by step+timestamp), preserve nested format {timestamp, step, metrics}
212
+ - Parameters: dict-merge (incoming overrides keys)
213
+ - Artifacts/logs: union with de-dup, preserve order
214
+ - Top-level scalar fields (e.g., status, name, description, created_at) update only when provided
215
  """
216
  if not self.dataset_manager:
217
  logger.warning("⚠️ Dataset manager not available")
 
287
  merged_metrics.append(nested)
288
  # else: ignore invalid metrics payload
289
  else:
290
+ # Treat as parameters and/or top-level updates
291
  try:
292
  if isinstance(experiment_data, dict):
293
+ # Extract known top-level fields (do not bury into parameters)
294
+ top_level_updates = {}
295
+ for k in ['status', 'name', 'description', 'created_at', 'experiment_end_time', 'final_metrics_count', 'total_artifacts']:
296
+ if k in experiment_data:
297
+ top_level_updates[k] = experiment_data[k]
298
+ # Remove them from parameters payload
299
+ param_updates = {k: v for k, v in experiment_data.items() if k not in top_level_updates}
300
+ # Apply param updates
301
+ merged_parameters.update(param_updates)
302
+ # Apply top-level updates to `existing` so they are reflected in the final record below
303
+ for k, v in top_level_updates.items():
304
+ existing[k] = v
305
  except Exception:
306
  pass
307
 
templates/spaces/trackio/dataset_utils.py CHANGED
@@ -90,34 +90,45 @@ class TrackioDatasetManager:
90
 
91
  def _validate_experiment_structure(self, experiment: Dict[str, Any]) -> bool:
92
  """
93
- Validate that an experiment has the required structure.
94
-
95
- Args:
96
- experiment (Dict[str, Any]): Experiment dictionary to validate
97
-
98
- Returns:
99
- bool: True if experiment structure is valid
100
  """
101
- required_fields = [
102
- 'experiment_id', 'name', 'description', 'created_at',
103
- 'status', 'metrics', 'parameters', 'artifacts', 'logs'
104
- ]
105
-
106
- for field in required_fields:
107
- if field not in experiment:
108
- logger.warning(f"⚠️ Missing required field '{field}' in experiment")
109
- return False
110
-
111
- # Validate JSON fields
112
- json_fields = ['metrics', 'parameters', 'artifacts', 'logs']
113
- for field in json_fields:
114
- if isinstance(experiment[field], str):
115
- try:
116
- json.loads(experiment[field])
117
- except json.JSONDecodeError:
118
- logger.warning(f"⚠️ Invalid JSON in field '{field}' for experiment {experiment.get('experiment_id')}")
119
- return False
120
-
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  return True
122
 
123
  def save_experiments(self, experiments: List[Dict[str, Any]], commit_message: Optional[str] = None) -> bool:
 
90
 
91
  def _validate_experiment_structure(self, experiment: Dict[str, Any]) -> bool:
92
  """
93
+ Validate and SANITIZE an experiment structure to prevent destructive failures.
94
+
95
+ - Requires 'experiment_id'; otherwise skip the row.
96
+ - Fills defaults for missing non-JSON fields.
97
+ - Normalizes JSON fields to valid JSON strings.
 
 
98
  """
99
+ if not experiment.get('experiment_id'):
100
+ logger.warning("⚠️ Missing required field 'experiment_id' in experiment; skipping row")
101
+ return False
102
+
103
+ defaults = {
104
+ 'name': '',
105
+ 'description': '',
106
+ 'created_at': datetime.now().isoformat(),
107
+ 'status': 'running',
108
+ }
109
+ for key, default_value in defaults.items():
110
+ if experiment.get(key) in (None, ''):
111
+ experiment[key] = default_value
112
+
113
+ def _ensure_json_string(field_name: str, default_value: Any):
114
+ raw_value = experiment.get(field_name)
115
+ try:
116
+ if isinstance(raw_value, str):
117
+ if raw_value.strip() == '':
118
+ experiment[field_name] = json.dumps(default_value, default=str)
119
+ else:
120
+ json.loads(raw_value)
121
+ else:
122
+ experiment[field_name] = json.dumps(
123
+ raw_value if raw_value is not None else default_value,
124
+ default=str
125
+ )
126
+ except Exception:
127
+ experiment[field_name] = json.dumps(default_value, default=str)
128
+
129
+ for json_field, default in (('metrics', []), ('parameters', {}), ('artifacts', []), ('logs', [])):
130
+ _ensure_json_string(json_field, default)
131
+
132
  return True
133
 
134
  def save_experiments(self, experiments: List[Dict[str, Any]], commit_message: Optional[str] = None) -> bool: