Tonic committed on
Commit
b11b94b
·
1 Parent(s): 3331c7f

adds improved dataset utils in tracktonic

Browse files
templates/spaces/trackio/dataset_utils.py CHANGED
@@ -122,12 +122,20 @@ class TrackioDatasetManager:
122
 
123
  def save_experiments(self, experiments: List[Dict[str, Any]], commit_message: Optional[str] = None) -> bool:
124
  """
125
- Save a list of experiments to the dataset, preserving data integrity.
126
-
 
 
 
 
 
 
 
 
127
  Args:
128
  experiments (List[Dict[str, Any]]): List of experiment dictionaries
129
  commit_message (Optional[str]): Custom commit message
130
-
131
  Returns:
132
  bool: True if save was successful, False otherwise
133
  """
@@ -136,24 +144,120 @@ class TrackioDatasetManager:
136
  logger.warning("⚠️ No experiments to save")
137
  return False
138
 
139
- # Validate all experiments before saving
140
- valid_experiments = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  for exp in experiments:
142
- if self._validate_experiment_structure(exp):
143
- # Ensure last_updated is set
144
- if 'last_updated' not in exp:
145
- exp['last_updated'] = datetime.now().isoformat()
146
- valid_experiments.append(exp)
147
- else:
148
  logger.error(f"❌ Invalid experiment structure: {exp.get('experiment_id', 'unknown')}")
149
  return False
150
-
151
- # Create dataset
152
- dataset = Dataset.from_list(valid_experiments)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
  # Generate commit message if not provided
155
  if not commit_message:
156
- commit_message = f"Update dataset with {len(valid_experiments)} experiments ({datetime.now().isoformat()})"
157
 
158
  # Push to hub
159
  dataset.push_to_hub(
@@ -163,7 +267,7 @@ class TrackioDatasetManager:
163
  commit_message=commit_message
164
  )
165
 
166
- logger.info(f"✅ Successfully saved {len(valid_experiments)} experiments to {self.dataset_repo}")
167
  return True
168
 
169
  except Exception as e:
 
122
 
123
  def save_experiments(self, experiments: List[Dict[str, Any]], commit_message: Optional[str] = None) -> bool:
124
  """
125
+ Save a list of experiments to the dataset using a non-destructive union merge.
126
+
127
+ - Loads existing experiments (if any) and builds a union by `experiment_id`.
128
+ - For overlapping IDs, merges JSON fields:
129
+ - metrics: concatenates lists and de-duplicates by (step, timestamp) for nested entries
130
+ - parameters: dict-update (new values override)
131
+ - artifacts: union with de-dup
132
+ - logs: concatenation with de-dup
133
+ - Non-JSON scalar fields from incoming experiments take precedence.
134
+
135
  Args:
136
  experiments (List[Dict[str, Any]]): List of experiment dictionaries
137
  commit_message (Optional[str]): Custom commit message
138
+
139
  Returns:
140
  bool: True if save was successful, False otherwise
141
  """
 
144
  logger.warning("⚠️ No experiments to save")
145
  return False
146
 
147
+ # Helpers
148
+ def _parse_json_field(value, default):
149
+ try:
150
+ if value is None:
151
+ return default
152
+ if isinstance(value, str):
153
+ return json.loads(value) if value else default
154
+ return value
155
+ except Exception:
156
+ return default
157
+
158
+ def _metrics_key(entry: Dict[str, Any]):
159
+ if isinstance(entry, dict):
160
+ return (entry.get('step'), entry.get('timestamp'))
161
+ return (None, json.dumps(entry, sort_keys=True))
162
+
163
+ # Load existing experiments for union merge
164
+ existing = {}
165
+ try:
166
+ for row in self.load_existing_experiments():
167
+ exp_id = row.get('experiment_id')
168
+ if exp_id:
169
+ existing[exp_id] = row
170
+ except Exception:
171
+ existing = {}
172
+
173
+ # Validate and merge
174
+ merged_map: Dict[str, Dict[str, Any]] = {}
175
+ # Seed with existing
176
+ for exp_id, row in existing.items():
177
+ merged_map[exp_id] = row
178
+
179
+ # Apply incoming
180
  for exp in experiments:
181
+ if not self._validate_experiment_structure(exp):
 
 
 
 
 
182
  logger.error(f"❌ Invalid experiment structure: {exp.get('experiment_id', 'unknown')}")
183
  return False
184
+ exp_id = exp['experiment_id']
185
+ incoming = exp
186
+ if exp_id not in merged_map:
187
+ incoming['last_updated'] = incoming.get('last_updated') or datetime.now().isoformat()
188
+ merged_map[exp_id] = incoming
189
+ continue
190
+ # Merge with existing
191
+ base = merged_map[exp_id]
192
+ # Parse JSON fields
193
+ base_metrics = _parse_json_field(base.get('metrics'), [])
194
+ base_params = _parse_json_field(base.get('parameters'), {})
195
+ base_artifacts = _parse_json_field(base.get('artifacts'), [])
196
+ base_logs = _parse_json_field(base.get('logs'), [])
197
+ inc_metrics = _parse_json_field(incoming.get('metrics'), [])
198
+ inc_params = _parse_json_field(incoming.get('parameters'), {})
199
+ inc_artifacts = _parse_json_field(incoming.get('artifacts'), [])
200
+ inc_logs = _parse_json_field(incoming.get('logs'), [])
201
+ # Merge metrics with de-dup
202
+ merged_metrics = []
203
+ seen = set()
204
+ for entry in base_metrics + inc_metrics:
205
+ try:
206
+ # Use the original entry so _metrics_key can properly
207
+ # distinguish dict vs non-dict entries
208
+ key = _metrics_key(entry)
209
+ except Exception:
210
+ key = (None, None)
211
+ if key not in seen:
212
+ seen.add(key)
213
+ merged_metrics.append(entry)
214
+ # Merge params
215
+ merged_params = {}
216
+ if isinstance(base_params, dict):
217
+ merged_params.update(base_params)
218
+ if isinstance(inc_params, dict):
219
+ merged_params.update(inc_params)
220
+ # Merge artifacts and logs with de-dup
221
+ def _dedup_list(lst):
222
+ out = []
223
+ seen_local = set()
224
+ for item in lst:
225
+ key = json.dumps(item, sort_keys=True, default=str) if not isinstance(item, str) else item
226
+ if key not in seen_local:
227
+ seen_local.add(key)
228
+ out.append(item)
229
+ return out
230
+ merged_artifacts = _dedup_list(list(base_artifacts) + list(inc_artifacts))
231
+ merged_logs = _dedup_list(list(base_logs) + list(inc_logs))
232
+ # Rebuild merged record preferring incoming scalars
233
+ merged_rec = dict(base)
234
+ merged_rec.update({k: v for k, v in incoming.items() if k not in ('metrics', 'parameters', 'artifacts', 'logs')})
235
+ merged_rec['metrics'] = json.dumps(merged_metrics, default=str)
236
+ merged_rec['parameters'] = json.dumps(merged_params, default=str)
237
+ merged_rec['artifacts'] = json.dumps(merged_artifacts, default=str)
238
+ merged_rec['logs'] = json.dumps(merged_logs, default=str)
239
+ merged_rec['last_updated'] = datetime.now().isoformat()
240
+ merged_map[exp_id] = merged_rec
241
+
242
+ # Prepare final list
243
+ valid_experiments = list(merged_map.values())
244
+ # Ensure all have mandatory fields encoded
245
+ normalized = []
246
+ for rec in valid_experiments:
247
+ # Normalize json fields to strings
248
+ for f, default in (('metrics', []), ('parameters', {}), ('artifacts', []), ('logs', [])):
249
+ val = rec.get(f)
250
+ if not isinstance(val, str):
251
+ rec[f] = json.dumps(val if val is not None else default, default=str)
252
+ if 'last_updated' not in rec:
253
+ rec['last_updated'] = datetime.now().isoformat()
254
+ normalized.append(rec)
255
+
256
+ dataset = Dataset.from_list(normalized)
257
 
258
  # Generate commit message if not provided
259
  if not commit_message:
260
+ commit_message = f"Union-merge update with {len(normalized)} experiments ({datetime.now().isoformat()})"
261
 
262
  # Push to hub
263
  dataset.push_to_hub(
 
267
  commit_message=commit_message
268
  )
269
 
270
+ logger.info(f"✅ Successfully saved {len(normalized)} experiments (union-merged) to {self.dataset_repo}")
271
  return True
272
 
273
  except Exception as e: