Tonic commited on
Commit
97dacc7
·
1 Parent(s): 598357a

adds defensive programming (boo) and adaptations based on transformer versions

Browse files
Files changed (1) hide show
  1. scripts/training/train_gpt_oss.py +54 -43
scripts/training/train_gpt_oss.py CHANGED
@@ -8,6 +8,7 @@ Based on the GPT-OSS fine-tuning tutorial
8
  import os
9
  import sys
10
  import argparse
 
11
  import torch
12
  from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
13
  from peft import LoraConfig, get_peft_model
@@ -386,62 +387,72 @@ def create_sft_config(config, output_dir):
386
  print(f" • Gradient accumulation: {gradient_accumulation_steps}")
387
  print(f" • Effective batch size: {per_device_train_batch_size * gradient_accumulation_steps}")
388
 
389
- sft_config = TrainingArguments(
 
390
  # Training duration
391
- num_train_epochs=num_train_epochs,
392
- max_steps=max_steps,
393
-
394
  # Learning rate
395
- learning_rate=learning_rate,
396
- lr_scheduler_type=lr_scheduler_type,
397
- warmup_ratio=warmup_ratio,
398
- warmup_steps=warmup_steps,
399
-
400
  # Batch configuration
401
- per_device_train_batch_size=per_device_train_batch_size,
402
- per_device_eval_batch_size=per_device_eval_batch_size,
403
- gradient_accumulation_steps=gradient_accumulation_steps,
404
-
405
  # Model configuration
406
- gradient_checkpointing=getattr(config, 'use_gradient_checkpointing', True),
407
-
408
  # Mixed precision
409
- fp16=fp16,
410
- bf16=bf16,
411
-
412
  # Regularization
413
- weight_decay=weight_decay,
414
- max_grad_norm=max_grad_norm,
415
-
416
- # Evaluation
417
- evaluation_strategy=eval_strategy,
418
- eval_steps=eval_steps,
419
-
420
  # Logging
421
- logging_steps=logging_steps,
422
-
423
  # Saving
424
- save_strategy=save_strategy,
425
- save_steps=save_steps,
426
- save_total_limit=save_total_limit,
427
-
428
  # Output
429
- output_dir=output_dir,
430
-
431
  # Data loading
432
- dataloader_num_workers=getattr(config, 'dataloader_num_workers', 4),
433
- dataloader_pin_memory=getattr(config, 'dataloader_pin_memory', True),
434
-
435
  # Performance
436
- group_by_length=getattr(config, 'group_by_length', True),
437
- remove_unused_columns=getattr(config, 'remove_unused_columns', True),
438
-
439
  # HuggingFace Hub
440
- push_to_hub=push_to_hub,
441
-
442
  # Monitoring
443
- report_to=("trackio" if getattr(config, 'enable_tracking', False) else None),
444
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445
 
446
  return sft_config
447
 
 
8
  import os
9
  import sys
10
  import argparse
11
+ import inspect
12
  import torch
13
  from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
14
  from peft import LoraConfig, get_peft_model
 
387
  print(f" • Gradient accumulation: {gradient_accumulation_steps}")
388
  print(f" • Effective batch size: {per_device_train_batch_size * gradient_accumulation_steps}")
389
 
390
+ # Build kwargs dynamically to be compatible across transformers versions
391
+ ta_kwargs = {
392
  # Training duration
393
+ "num_train_epochs": num_train_epochs,
394
+ "max_steps": max_steps,
 
395
  # Learning rate
396
+ "learning_rate": learning_rate,
397
+ "lr_scheduler_type": lr_scheduler_type,
398
+ "warmup_ratio": warmup_ratio,
399
+ "warmup_steps": warmup_steps,
 
400
  # Batch configuration
401
+ "per_device_train_batch_size": per_device_train_batch_size,
402
+ "per_device_eval_batch_size": per_device_eval_batch_size,
403
+ "gradient_accumulation_steps": gradient_accumulation_steps,
 
404
  # Model configuration
405
+ "gradient_checkpointing": getattr(config, 'use_gradient_checkpointing', True),
 
406
  # Mixed precision
407
+ "fp16": fp16,
408
+ "bf16": bf16,
 
409
  # Regularization
410
+ "weight_decay": weight_decay,
411
+ "max_grad_norm": max_grad_norm,
412
+ # Evaluation (name may vary across versions)
413
+ "evaluation_strategy": eval_strategy,
414
+ "eval_steps": eval_steps,
 
 
415
  # Logging
416
+ "logging_steps": logging_steps,
 
417
  # Saving
418
+ "save_strategy": save_strategy,
419
+ "save_steps": save_steps,
420
+ "save_total_limit": save_total_limit,
 
421
  # Output
422
+ "output_dir": output_dir,
 
423
  # Data loading
424
+ "dataloader_num_workers": getattr(config, 'dataloader_num_workers', 4),
425
+ "dataloader_pin_memory": getattr(config, 'dataloader_pin_memory', True),
 
426
  # Performance
427
+ "group_by_length": getattr(config, 'group_by_length', True),
428
+ "remove_unused_columns": getattr(config, 'remove_unused_columns', True),
 
429
  # HuggingFace Hub
430
+ "push_to_hub": push_to_hub,
 
431
  # Monitoring
432
+ "report_to": ("trackio" if getattr(config, 'enable_tracking', False) else None),
433
+ }
434
+
435
+ # Adapt to transformers versions where 'evaluation_strategy' was renamed
436
+ try:
437
+ ta_sig = inspect.signature(TrainingArguments.__init__)
438
+ param_names = set(ta_sig.parameters.keys())
439
+ except Exception:
440
+ param_names = set()
441
+
442
+ if "evaluation_strategy" not in param_names and "eval_strategy" in param_names:
443
+ # Move value to 'eval_strategy'
444
+ ta_kwargs["eval_strategy"] = ta_kwargs.pop("evaluation_strategy")
445
+ elif "evaluation_strategy" not in param_names:
446
+ # If neither is supported, drop it
447
+ ta_kwargs.pop("evaluation_strategy", None)
448
+
449
+ # Remove any kwargs not supported by current transformers version
450
+ if param_names:
451
+ unsupported = [k for k in ta_kwargs.keys() if k not in param_names]
452
+ for k in unsupported:
453
+ ta_kwargs.pop(k, None)
454
+
455
+ sft_config = TrainingArguments(**ta_kwargs)
456
 
457
  return sft_config
458