Tonic committed
Commit 3331c7f · 1 Parent(s): fa9560d

adds optimizations for faster training

config/train_gpt_oss_openhermes_fr_memory_optimized.py CHANGED
@@ -56,7 +56,7 @@ config = GPTOSSEnhancedCustomConfig(
     # MODEL CONFIGURATION - Memory Optimized for GPT-OSS
     # ============================================================================
     model_name="openai/gpt-oss-20b",
-    max_seq_length=4096,  # Maximize sequence length for A100 VRAM utilization
+    max_seq_length=2048,  # Shorter context speeds steps without reducing sample count
     use_flash_attention=True,  # Critical for memory efficiency
     use_gradient_checkpointing=True,  # Essential for memory optimization
 
@@ -115,9 +115,10 @@ config = GPTOSSEnhancedCustomConfig(
     },
 
     # Data loading optimized for throughput
-    dataloader_num_workers=4,  # More workers for faster loading
+    dataloader_num_workers=8,  # More workers for faster loading
     dataloader_pin_memory=True,  # Pin memory for faster host->GPU copies
-    dataloader_prefetch_factor=1,  # Lower prefetch to keep VRAM headroom
+    dataloader_prefetch_factor=2,  # Slightly higher prefetch for throughput
+    dataset_num_proc=8,  # Parallelize HF datasets map/filter
 
     # Memory management optimizations
     max_memory_per_gpu=None,  # No explicit memory limit; use as much VRAM as available
@@ -197,6 +198,9 @@ config = GPTOSSEnhancedCustomConfig(
     "min_lr": 2e-6,  # Explicit absolute floor (matches min_lr above)
     "warmup_steps": None,  # Use warmup_ratio instead
     },
+
+    # Packing to increase token utilization per step (supported by TRL)
+    packing=True,
 
     # ============================================================================
     # MONITORING & HUB INTEGRATION
 
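Note on the max_seq_length change: halving the context only avoids starving steps of data because of the packing=True added at the end of this diff, which concatenates several short OpenHermes-FR samples into each 2048-token block. The sketch below is a toy illustration of greedy packing with made-up lengths, not TRL's actual implementation:

    # Toy greedy packer: concatenate example lengths into fixed-size blocks
    # so that little of each training step is padding.
    def pack(example_lengths, block_size=2048):
        blocks, current = [], 0
        for n in example_lengths:
            n = min(n, block_size)        # truncate over-long examples
            if current + n > block_size:  # block is full; start a new one
                blocks.append(current)
                current = 0
            current += n
        if current:
            blocks.append(current)
        return blocks

    toy_lengths = [300, 900, 1500, 200, 700, 2500]  # hypothetical sample lengths
    blocks = pack(toy_lengths)
    print(blocks)  # [1200, 1700, 700, 2048]
    print(f"utilization: {sum(blocks) / (len(blocks) * 2048):.0%}")  # ~69%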
scripts/training/train_gpt_oss.py CHANGED
@@ -210,6 +210,13 @@ def build_scheduler_kwargs(config):
 def apply_dataset_filtering(dataset, config):
     """Apply filtering based on configuration"""
 
+    # Parallel workers for datasets ops
+    try:
+        import os as _os
+        num_proc = getattr(config, 'dataset_num_proc', None) or (_os.cpu_count() or 1)
+    except Exception:
+        num_proc = 1
+
     # Filter bad entries if specified
     if getattr(config, 'filter_bad_entries', False):
         bad_entry_field = getattr(config, 'bad_entry_field', 'bad_entry')
@@ -220,17 +227,23 @@ def apply_dataset_filtering(dataset, config):
 
         # Filter out bad entries
         if bad_entry_field in dataset.column_names:
-            dataset = dataset.filter(lambda x: not x.get(bad_entry_field, False))
+            def _keep_not_bad_entry(example, _field=bad_entry_field):
+                return not example.get(_field, False)
+            dataset = dataset.filter(_keep_not_bad_entry, num_proc=num_proc)
             print(f"Filtered {original_size - len(dataset)} bad entries")
 
         # Filter out bad prompts
         if bad_prompt_field in dataset.column_names:
-            dataset = dataset.filter(lambda x: not x.get(bad_prompt_field, False))
+            def _keep_not_bad_prompt(example, _field=bad_prompt_field):
+                return not example.get(_field, False)
+            dataset = dataset.filter(_keep_not_bad_prompt, num_proc=num_proc)
             print(f"Filtered bad prompts, remaining: {len(dataset)} examples")
 
         # Filter out bad responses
         if bad_response_field in dataset.column_names:
-            dataset = dataset.filter(lambda x: not x.get(bad_response_field, False))
+            def _keep_not_bad_response(example, _field=bad_response_field):
+                return not example.get(_field, False)
+            dataset = dataset.filter(_keep_not_bad_response, num_proc=num_proc)
             print(f"Filtered bad responses, remaining: {len(dataset)} examples")
 
         # Apply length filtering
@@ -253,7 +266,7 @@ def apply_dataset_filtering(dataset, config):
         return True
 
     original_size = len(dataset)
-    dataset = dataset.filter(length_filter)
+    dataset = dataset.filter(length_filter, num_proc=num_proc)
     print(f"Length filtering: {original_size} -> {len(dataset)} examples")
 
     # Apply sampling if specified
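These hunks swap inline lambdas for named predicates that bind the column name as a default argument. With num_proc, datasets ships the predicate to worker processes; binding the field by value keeps each predicate stable across workers and sidesteps Python's late-binding closure pitfall. A minimal self-contained sketch of the pattern (toy data, illustrative names):

    from datasets import Dataset

    ds = Dataset.from_dict({
        "prompt": ["a", "b", "c"],
        "bad_entry": [False, True, False],
    })

    field = "bad_entry"
    def _keep_not_bad(example, _field=field):  # _field is captured by value here
        return not example.get(_field, False)

    field = "other_column"  # later reassignment does not affect the predicate
    ds = ds.filter(_keep_not_bad, num_proc=2)
    print(ds["prompt"])  # ['a', 'c']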
@@ -293,6 +306,13 @@ def format_gpt_oss_harmony_prompt(prompt: str) -> str:
 def process_dataset_format(dataset, config):
     """Process dataset based on format configuration with exact GPT-OSS Harmony compliance"""
 
+    # Parallel workers for datasets ops
+    try:
+        import os as _os
+        num_proc = getattr(config, 'dataset_num_proc', None) or (_os.cpu_count() or 1)
+    except Exception:
+        num_proc = 1
+
     dataset_format = getattr(config, 'dataset_format', 'openhermes_fr')
     input_field = getattr(config, 'input_field', 'prompt')
     target_field = getattr(config, 'target_field', 'accepted_completion')
@@ -325,7 +345,7 @@ def process_dataset_format(dataset, config):
             return {"prompt": prompt_val, "chosen": chosen_val, "rejected": rejected_val}
 
         keep_cols = [c for c in ['prompt', 'chosen', 'rejected'] if c in dataset.column_names]
-        dataset = dataset.map(id_map, remove_columns=dataset.column_names if keep_cols else dataset.column_names)
+        dataset = dataset.map(id_map, remove_columns=dataset.column_names if keep_cols else dataset.column_names, num_proc=num_proc)
         return dataset
 
     # Custom preference mapping via configured field names
@@ -341,7 +361,7 @@ def process_dataset_format(dataset, config):
                 return {"prompt": prompt_text, "chosen": chosen_text, "rejected": rejected_text}
             return {"prompt": prompt_val, "chosen": chosen_val, "rejected": rejected_val}
 
-        dataset = dataset.map(to_pref, remove_columns=dataset.column_names)
+        dataset = dataset.map(to_pref, remove_columns=dataset.column_names, num_proc=num_proc)
         return dataset
 
     # If we reach here, we don't have required fields for DPO
@@ -371,7 +391,7 @@ def process_dataset_format(dataset, config):
                 "output": completion
             }
 
-        dataset = dataset.map(format_openhermes_fr, remove_columns=dataset.column_names)
+        dataset = dataset.map(format_openhermes_fr, remove_columns=dataset.column_names, num_proc=num_proc)
 
     elif dataset_format == "messages":
         # Process messages format (like HuggingFaceH4/Multilingual-Thinking)
@@ -416,7 +436,7 @@ def process_dataset_format(dataset, config):
 
             return {"text": text}
 
-        dataset = dataset.map(format_messages, remove_columns=dataset.column_names)
+        dataset = dataset.map(format_messages, remove_columns=dataset.column_names, num_proc=num_proc)
 
     elif dataset_format == "text":
         # Process plain text format
@@ -427,7 +447,7 @@ def process_dataset_format(dataset, config):
             text += "</s>"
             return {"text": text}
 
-        dataset = dataset.map(format_text, remove_columns=dataset.column_names)
+        dataset = dataset.map(format_text, remove_columns=dataset.column_names, num_proc=num_proc)
 
     elif dataset_format == "custom":
         # Custom format - user handles this in their config
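Same idea for the map calls: num_proc shards the work across processes, and remove_columns=dataset.column_names drops every source column so only the fields the mapping function returns survive. A small self-contained sketch (toy rows; the real formatting functions are in the hunks above):

    from datasets import Dataset

    ds = Dataset.from_dict({"prompt": ["Bonjour", "Salut"],
                            "accepted_completion": ["Hello", "Hi"]})

    def to_text(example):
        # Keep only a single rendered training field
        return {"text": f"{example['prompt']}\n{example['accepted_completion']}"}

    ds = ds.map(to_text, remove_columns=ds.column_names, num_proc=2)
    print(ds.column_names)  # ['text']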
@@ -652,6 +672,8 @@ def create_sft_config(config, output_dir):
         "bf16": bf16,
         # Some versions support tf32
         "tf32": tf32 if 'tf32' in TrainingArguments.__init__.__code__.co_varnames else None,
+        # Optimizer (optionally use fused AdamW if available through config)
+        "optim": getattr(config, 'optimizer', 'adamw_torch'),
         # Regularization
         "weight_decay": weight_decay,
         "max_grad_norm": max_grad_norm,
@@ -828,6 +850,10 @@ def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer
     if "max_seq_length" in sft_params:
         sft_kwargs["max_seq_length"] = getattr(config, 'max_seq_length', 2048)
 
+    # Enable sequence packing if supported by TRL (speeds up token utilization)
+    if "packing" in sft_params:
+        sft_kwargs["packing"] = getattr(config, 'packing', False)
+
     # Remove any None values
     sft_kwargs = {k: v for k, v in sft_kwargs.items() if v is not None}
 
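Both additions use the same version-tolerance pattern: pass a keyword only if the installed library accepts it. The construction of sft_params is outside these hunks, so the sketch below assumes it comes from SFTConfig's signature; similarly, on recent transformers a config optimizer value of "adamw_torch_fused" would route fused AdamW through the "optim" key.

    import inspect
    from trl import SFTConfig

    # Assumption: sft_params is the set of keyword names accepted by the
    # installed TRL version's SFTConfig (its construction is not shown here).
    sft_params = set(inspect.signature(SFTConfig.__init__).parameters)

    sft_kwargs = {}
    if "packing" in sft_params:  # pass only flags this TRL version understands
        sft_kwargs["packing"] = True
    print("packing supported:", "packing" in sft_params)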
src/dataset_utils.py CHANGED
@@ -122,12 +122,20 @@ class TrackioDatasetManager:
 
     def save_experiments(self, experiments: List[Dict[str, Any]], commit_message: Optional[str] = None) -> bool:
         """
-        Save a list of experiments to the dataset, preserving data integrity.
-
+        Save a list of experiments to the dataset using a non-destructive union merge.
+
+        - Loads existing experiments (if any) and builds a union by `experiment_id`.
+        - For overlapping IDs, merges JSON fields:
+          - metrics: concatenates lists and de-duplicates by (step, timestamp) for nested entries
+          - parameters: dict-update (new values override)
+          - artifacts: union with de-dup
+          - logs: concatenation with de-dup
+        - Non-JSON scalar fields from incoming experiments take precedence.
+
         Args:
             experiments (List[Dict[str, Any]]): List of experiment dictionaries
             commit_message (Optional[str]): Custom commit message
-
+
         Returns:
             bool: True if save was successful, False otherwise
         """
@@ -136,24 +144,120 @@ class TrackioDatasetManager:
                 logger.warning("⚠️ No experiments to save")
                 return False
 
-            # Validate all experiments before saving
-            valid_experiments = []
+            # Helpers
+            def _parse_json_field(value, default):
+                try:
+                    if value is None:
+                        return default
+                    if isinstance(value, str):
+                        return json.loads(value) if value else default
+                    return value
+                except Exception:
+                    return default
+
+            def _metrics_key(entry: Dict[str, Any]):
+                if isinstance(entry, dict):
+                    return (entry.get('step'), entry.get('timestamp'))
+                return (None, json.dumps(entry, sort_keys=True))
+
+            # Load existing experiments for union merge
+            existing = {}
+            try:
+                for row in self.load_existing_experiments():
+                    exp_id = row.get('experiment_id')
+                    if exp_id:
+                        existing[exp_id] = row
+            except Exception:
+                existing = {}
+
+            # Validate and merge
+            merged_map: Dict[str, Dict[str, Any]] = {}
+            # Seed with existing
+            for exp_id, row in existing.items():
+                merged_map[exp_id] = row
+
+            # Apply incoming
             for exp in experiments:
-                if self._validate_experiment_structure(exp):
-                    # Ensure last_updated is set
-                    if 'last_updated' not in exp:
-                        exp['last_updated'] = datetime.now().isoformat()
-                    valid_experiments.append(exp)
-                else:
+                if not self._validate_experiment_structure(exp):
                     logger.error(f"❌ Invalid experiment structure: {exp.get('experiment_id', 'unknown')}")
                     return False
-
-            # Create dataset
-            dataset = Dataset.from_list(valid_experiments)
+                exp_id = exp['experiment_id']
+                incoming = exp
+                if exp_id not in merged_map:
+                    incoming['last_updated'] = incoming.get('last_updated') or datetime.now().isoformat()
+                    merged_map[exp_id] = incoming
+                    continue
+                # Merge with existing
+                base = merged_map[exp_id]
+                # Parse JSON fields
+                base_metrics = _parse_json_field(base.get('metrics'), [])
+                base_params = _parse_json_field(base.get('parameters'), {})
+                base_artifacts = _parse_json_field(base.get('artifacts'), [])
+                base_logs = _parse_json_field(base.get('logs'), [])
+                inc_metrics = _parse_json_field(incoming.get('metrics'), [])
+                inc_params = _parse_json_field(incoming.get('parameters'), {})
+                inc_artifacts = _parse_json_field(incoming.get('artifacts'), [])
+                inc_logs = _parse_json_field(incoming.get('logs'), [])
+                # Merge metrics with de-dup
+                merged_metrics = []
+                seen = set()
+                for entry in base_metrics + inc_metrics:
+                    try:
+                        # Use the original entry so _metrics_key can properly
+                        # distinguish dict vs non-dict entries
+                        key = _metrics_key(entry)
+                    except Exception:
+                        key = (None, None)
+                    if key not in seen:
+                        seen.add(key)
+                        merged_metrics.append(entry)
+                # Merge params
+                merged_params = {}
+                if isinstance(base_params, dict):
+                    merged_params.update(base_params)
+                if isinstance(inc_params, dict):
+                    merged_params.update(inc_params)
+                # Merge artifacts and logs with de-dup
+                def _dedup_list(lst):
+                    out = []
+                    seen_local = set()
+                    for item in lst:
+                        key = json.dumps(item, sort_keys=True, default=str) if not isinstance(item, str) else item
+                        if key not in seen_local:
+                            seen_local.add(key)
+                            out.append(item)
+                    return out
+                merged_artifacts = _dedup_list(list(base_artifacts) + list(inc_artifacts))
+                merged_logs = _dedup_list(list(base_logs) + list(inc_logs))
+                # Rebuild merged record preferring incoming scalars
+                merged_rec = dict(base)
+                merged_rec.update({k: v for k, v in incoming.items() if k not in ('metrics', 'parameters', 'artifacts', 'logs')})
+                merged_rec['metrics'] = json.dumps(merged_metrics, default=str)
+                merged_rec['parameters'] = json.dumps(merged_params, default=str)
+                merged_rec['artifacts'] = json.dumps(merged_artifacts, default=str)
+                merged_rec['logs'] = json.dumps(merged_logs, default=str)
+                merged_rec['last_updated'] = datetime.now().isoformat()
+                merged_map[exp_id] = merged_rec
+
+            # Prepare final list
+            valid_experiments = list(merged_map.values())
+            # Ensure all have mandatory fields encoded
+            normalized = []
+            for rec in valid_experiments:
+                # Normalize json fields to strings
+                for f, default in (('metrics', []), ('parameters', {}), ('artifacts', []), ('logs', [])):
+                    val = rec.get(f)
+                    if not isinstance(val, str):
+                        rec[f] = json.dumps(val if val is not None else default, default=str)
+                if 'last_updated' not in rec:
+                    rec['last_updated'] = datetime.now().isoformat()
+                normalized.append(rec)
+
+            dataset = Dataset.from_list(normalized)
 
             # Generate commit message if not provided
             if not commit_message:
-                commit_message = f"Update dataset with {len(valid_experiments)} experiments ({datetime.now().isoformat()})"
+                commit_message = f"Union-merge update with {len(normalized)} experiments ({datetime.now().isoformat()})"
 
             # Push to hub
             dataset.push_to_hub(
@@ -163,7 +267,7 @@ class TrackioDatasetManager:
                 commit_message=commit_message
             )
 
-            logger.info(f"✅ Successfully saved {len(valid_experiments)} experiments to {self.dataset_repo}")
+            logger.info(f"✅ Successfully saved {len(normalized)} experiments (union-merged) to {self.dataset_repo}")
             return True
 
         except Exception as e:
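The save path is now idempotent: the (step, timestamp) key means pushing the same metrics twice adds nothing on the second push. A minimal runnable sketch of that key, lifted from the hunk above (toy entries):

    import json

    def _metrics_key(entry):
        if isinstance(entry, dict):
            return (entry.get('step'), entry.get('timestamp'))
        return (None, json.dumps(entry, sort_keys=True))

    a = {"step": 1, "timestamp": "t1", "loss": 2.3}
    b = {"step": 1, "timestamp": "t1", "loss": 2.3}  # duplicate key -> dropped
    c = {"step": 2, "timestamp": "t2", "loss": 1.9}
    print(len({_metrics_key(e) for e in (a, b, c)}))  # 2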
 