Tonic committed
Commit d455d12 · 1 Parent(s): 71db310

adds scheduler stuff and hopes for the best with track tonic

config/train_gpt_oss_custom.py CHANGED
@@ -59,7 +59,7 @@ class GPTOSSEnhancedCustomConfig:
     # ============================================================================
     # SCHEDULER CONFIGURATION
     # ============================================================================
-    scheduler: str = "cosine_with_min_lr"  # "linear", "cosine", "cosine_with_min_lr", "constant"
+    scheduler: str = "cosine"  # Default to broadly compatible scheduler; TRL special is opt-in
     lr_scheduler_kwargs: Optional[Dict] = None
 
     # ============================================================================
@@ -299,7 +299,8 @@ class GPTOSSEnhancedCustomConfig:
         # SCHEDULER CONFIGURATION DEFAULTS
         # ============================================================================
         if self.lr_scheduler_kwargs is None:
-            self.lr_scheduler_kwargs = {"min_lr_rate": 0.1}
+            # Leave empty; training script will add TRL-specific keys only when needed
+            self.lr_scheduler_kwargs = {}
 
         # ============================================================================
         # CHAT TEMPLATE CONFIGURATION DEFAULTS (GPT-OSS Harmony Format)
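
Net effect of this change, sketched below (not part of the commit): the config dataclass now defaults to the plain cosine schedule with empty lr_scheduler_kwargs, and the TRL-specific min-LR keys are only filled in later by the training script. The import path and the no-argument construction are assumptions for illustration.

# Illustration only (assumes the dataclass is importable this way and that all
# required fields have defaults).
from config.train_gpt_oss_custom import GPTOSSEnhancedCustomConfig

cfg = GPTOSSEnhancedCustomConfig()
assert cfg.scheduler == "cosine"        # broadly compatible default
assert cfg.lr_scheduler_kwargs == {}    # left empty; TRL-specific keys added only when needed
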
config/train_gpt_oss_medical_o1_sft.py CHANGED
@@ -65,6 +65,10 @@ config = GPTOSSEnhancedCustomConfig(
     warmup_steps=50,
     max_grad_norm=1.0,
 
+    # Scheduler: use broadly compatible cosine by default to avoid TRL signature issues
+    scheduler="cosine",
+    lr_scheduler_kwargs={},
+
     # Sequence length
     max_seq_length=2048,
 
config/train_gpt_oss_openhermes_fr_memory_optimized.py CHANGED
@@ -193,11 +193,9 @@ config = GPTOSSEnhancedCustomConfig(
     beta2=0.95,  # GPT-OSS optimized beta2
     eps=1e-8,
 
-    scheduler="cosine_with_min_lr",  # Stable scheduler for single epoch
-    lr_scheduler_kwargs={
-        "min_lr": 2e-6,  # Explicit absolute floor (matches min_lr above)
-        "warmup_steps": None,  # Use warmup_ratio instead
-    },
+    # Use standard cosine for broad compatibility; TRL min-lr scheduler is optional
+    scheduler="cosine",
+    lr_scheduler_kwargs={},
 
     # Packing to increase token utilization per step (supported by TRL)
     packing=True,
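
Configs that still want TRL's minimum-LR cosine schedule can opt in explicitly; a hedged sketch, with field values illustrative only:

# Hypothetical opt-in, not part of this commit: keep TRL's scheduler for a specific run.
config = GPTOSSEnhancedCustomConfig(
    # ... other fields unchanged ...
    scheduler="cosine_with_min_lr",
    lr_scheduler_kwargs={"min_lr": 2e-6},  # absolute LR floor, as the previous default in this file
)
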
requirements/requirements.txt CHANGED
@@ -19,7 +19,8 @@ numpy>=1.24.0
 tqdm>=4.65.0
 
 # Experiment tracking
-trackio>=0.1.0
+# trackio>=0.1.0
+gradio>=5.0.0
 
 # Optional: for evaluation (commented out to reduce conflicts)
 # lighteval>=0.1.0
requirements/requirements_minimal.txt CHANGED
@@ -10,5 +10,4 @@ tokenizers>=0.13.0
 bitsandbytes>=0.41.0
 numpy>=1.24.0
 tqdm>=4.65.0
-trackio>=0.1.0
 psutil>=5.9.0
scripts/training/train_gpt_oss.py CHANGED
@@ -191,12 +191,26 @@ def load_dataset_from_config(config):
     return dataset
 
 def build_scheduler_kwargs(config):
-    """Construct lr_scheduler_kwargs ensuring one of min_lr or min_lr_rate is set.
-    Falls back to config.min_lr or a default rate of 0.1.
+    """Construct lr_scheduler_kwargs compatibly across TRL/Transformers versions.
+
+    - For TRL's 'cosine_with_min_lr' scheduler, ensure a min_lr/min_lr_rate is set.
+    - For all other schedulers, strip TRL-specific keys to avoid unexpected kwargs
+      errors in Transformers' native schedulers.
     """
     skw = getattr(config, 'lr_scheduler_kwargs', {}) or {}
     if not isinstance(skw, dict):
         skw = {}
+
+    scheduler_type = getattr(config, 'scheduler', None)
+
+    # If we're NOT using TRL's special scheduler, drop incompatible keys early
+    if scheduler_type != 'cosine_with_min_lr':
+        for k in ('min_lr', 'min_lr_rate', 'warmup_steps', 'num_warmup_steps', 'warmup_ratio'):
+            if k in skw:
+                skw.pop(k, None)
+        return skw
+
+    # TRL cosine-with-min-lr: ensure one of min_lr or min_lr_rate is provided
     min_lr_cfg = getattr(config, 'min_lr', 1e-6)
     if 'min_lr' not in skw and 'min_lr_rate' not in skw:
         try:
@@ -206,6 +220,7 @@ def build_scheduler_kwargs(config):
             skw['min_lr_rate'] = 0.1
         except Exception:
             skw['min_lr_rate'] = 0.001
+
     # Remove warmup-related keys which conflict with some TRL schedulers
     for k in ('warmup_steps', 'num_warmup_steps', 'warmup_ratio'):
         if k in skw:
@@ -683,7 +698,8 @@ def create_sft_config(config, output_dir):
 
     # Learning rate configuration
     learning_rate = _as_float(getattr(config, 'learning_rate', 2e-4), 2e-4)
-    lr_scheduler_type = getattr(config, 'scheduler', 'cosine_with_min_lr')
+    # Allow CLI/env override of scheduler
+    lr_scheduler_type = os.environ.get('GPT_OSS_SCHEDULER', getattr(config, 'scheduler', 'cosine'))
     lr_scheduler_kwargs = build_scheduler_kwargs(config)
 
     # Detect TRL scheduler signature incompatibilities and fall back gracefully
@@ -865,6 +881,57 @@ def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer
     config.experiment_name = experiment_name
     config.trackio_url = trackio_url
     config.trainer_type = trainer_type
+
+    # Optional: scheduler overrides via environment variables set by CLI
+    try:
+        env_scheduler = os.environ.get("GPT_OSS_SCHEDULER")
+        if env_scheduler:
+            # Apply scheduler override
+            config.scheduler = env_scheduler
+            # Prepare/normalize lr scheduler kwargs container
+            if not hasattr(config, 'lr_scheduler_kwargs') or config.lr_scheduler_kwargs is None:
+                config.lr_scheduler_kwargs = {}
+
+            # Apply min lr overrides only when using TRL's special scheduler
+            if env_scheduler == 'cosine_with_min_lr':
+                env_min_lr = os.environ.get("GPT_OSS_MIN_LR")
+                env_min_lr_rate = os.environ.get("GPT_OSS_MIN_LR_RATE")
+                # Clear conflicting warmup keys to avoid signature issues
+                for k in ('warmup_steps', 'num_warmup_steps', 'warmup_ratio'):
+                    if k in config.lr_scheduler_kwargs:
+                        config.lr_scheduler_kwargs.pop(k, None)
+                # Prefer absolute min_lr if provided
+                if env_min_lr is not None:
+                    try:
+                        config.min_lr = float(env_min_lr)
+                        config.lr_scheduler_kwargs['min_lr'] = config.min_lr
+                        # Remove relative rate if present
+                        config.lr_scheduler_kwargs.pop('min_lr_rate', None)
+                    except Exception:
+                        pass
+                elif env_min_lr_rate is not None:
+                    try:
+                        config.lr_scheduler_kwargs['min_lr_rate'] = float(env_min_lr_rate)
+                        # Remove absolute min_lr if present in kwargs (leave config.min_lr untouched)
+                        config.lr_scheduler_kwargs.pop('min_lr', None)
+                    except Exception:
+                        pass
+                else:
+                    # Ensure at least one constraint exists; prefer absolute from config if valid
+                    try:
+                        if hasattr(config, 'min_lr') and config.min_lr is not None:
+                            config.lr_scheduler_kwargs['min_lr'] = float(config.min_lr)
+                        else:
+                            config.lr_scheduler_kwargs.setdefault('min_lr_rate', 0.1)
+                    except Exception:
+                        config.lr_scheduler_kwargs.setdefault('min_lr_rate', 0.1)
+            else:
+                # Non-TRL scheduler: strip TRL-specific keys to avoid unexpected kwargs
+                if hasattr(config, 'lr_scheduler_kwargs') and isinstance(config.lr_scheduler_kwargs, dict):
+                    for k in ('min_lr', 'min_lr_rate'):
+                        config.lr_scheduler_kwargs.pop(k, None)
+    except Exception:
+        pass
 
     # Load model and tokenizer
     model, tokenizer = load_gpt_oss_model_and_tokenizer(config)
@@ -1027,6 +1094,24 @@ def main():
     parser.add_argument("--output-dir", required=True, help="Output directory for checkpoints")
     parser.add_argument("--trackio-url", help="Trackio URL for monitoring")
    parser.add_argument("--trainer-type", default="sft", choices=["sft", "dpo"], help="Trainer type")
+    # Optional LR scheduler overrides (applied across any GPT-OSS config)
+    parser.add_argument(
+        "--scheduler",
+        choices=["linear", "cosine", "cosine_with_min_lr", "constant"],
+        help="Override LR scheduler for this run",
+    )
+    parser.add_argument(
+        "--min-lr",
+        type=float,
+        dest="min_lr",
+        help="Absolute floor for LR (used when scheduler is 'cosine_with_min_lr')",
+    )
+    parser.add_argument(
+        "--min-lr-rate",
+        type=float,
+        dest="min_lr_rate",
+        help="Relative LR floor rate in (0,1) for TRL scheduler (used when scheduler is 'cosine_with_min_lr')",
+    )
 
     args = parser.parse_args()
 
@@ -1039,7 +1124,16 @@
     os.makedirs(args.output_dir, exist_ok=True)
 
     try:
-        train_gpt_oss(
+        # If provided, expose scheduler overrides via environment so they can be picked up consistently
+        # across helper functions if needed.
+        if args.scheduler:
+            os.environ["GPT_OSS_SCHEDULER"] = args.scheduler
+        if args.min_lr is not None:
+            os.environ["GPT_OSS_MIN_LR"] = str(args.min_lr)
+        if args.min_lr_rate is not None:
+            os.environ["GPT_OSS_MIN_LR_RATE"] = str(args.min_lr_rate)
+
+        trainer = train_gpt_oss(
             config_path=args.config,
             experiment_name=args.experiment_name,
             output_dir=args.output_dir,
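
Taken together, the new --scheduler, --min-lr, and --min-lr-rate flags are forwarded through the GPT_OSS_SCHEDULER, GPT_OSS_MIN_LR, and GPT_OSS_MIN_LR_RATE environment variables, and build_scheduler_kwargs sanitizes the kwargs per scheduler type. A rough sketch of the resulting behavior (assumes build_scheduler_kwargs from scripts/training/train_gpt_oss.py is in scope; SimpleNamespace stands in for a real config object):

from types import SimpleNamespace

# Non-TRL scheduler: TRL-specific and warmup keys are stripped before reaching Transformers.
cfg = SimpleNamespace(scheduler="cosine",
                      lr_scheduler_kwargs={"min_lr_rate": 0.1, "warmup_steps": 50},
                      min_lr=1e-6)
print(build_scheduler_kwargs(cfg))  # {}

# TRL's cosine_with_min_lr: a floor is guaranteed; whether it lands as 'min_lr' or
# 'min_lr_rate' depends on lines of the function not shown in this diff.
cfg = SimpleNamespace(scheduler="cosine_with_min_lr",
                      lr_scheduler_kwargs={},
                      min_lr=2e-6)
print(build_scheduler_kwargs(cfg))  # e.g. {'min_lr': 2e-06} or {'min_lr_rate': 0.1}
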