Tonic committed on
Commit 0fa6045 · 1 Parent(s): 665844a

Adds a better launch.sh and automatic eval/test splits

config/train_gpt_oss_basic.py CHANGED
@@ -62,6 +62,9 @@ class GPTOSSBasicConfig:
     metric_for_best_model: str = "eval_loss"
     greater_is_better: bool = False
     load_best_model_at_end: bool = True
+    eval_accumulation_steps: Optional[int] = None
+    eval_ratio: float = 0.01
+    test_ratio: float = 0.01
 
     # Data configuration
     dataset_name: str = "HuggingFaceH4/Multilingual-Thinking"
@@ -99,6 +102,13 @@ class GPTOSSBasicConfig:
 
     # GPT-OSS specific model kwargs
     model_kwargs: dict = None
+    # Performance and precision extras
+    dataloader_prefetch_factor: int = 2
+    tf32: Optional[bool] = None
+    # DPO preference training fields
+    chosen_field: Optional[str] = None
+    rejected_field: Optional[str] = None
+    dpo_beta: float = 0.1
 
     def __post_init__(self):
         if self.chat_template_kwargs is None:
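For orientation, a minimal sketch (not from the repo) of how the new preference-training fields might be set when instantiating this config; the dataset repo id and column names are hypothetical, and the DPO branch that consumes them lives in scripts/training/train_gpt_oss.py further down in this commit.

# Hypothetical usage of the new GPTOSSBasicConfig fields added in this commit.
from config.train_gpt_oss_basic import GPTOSSBasicConfig  # assumes configs are importable as a package

config = GPTOSSBasicConfig(
    dataset_name="my-org/my-preference-pairs",  # hypothetical DPO-style dataset
    chosen_field="chosen",        # column holding the preferred completion
    rejected_field="rejected",    # column holding the dispreferred completion
    dpo_beta=0.1,                 # strength of the DPO preference loss
    eval_ratio=0.01,              # hold out 1% of rows for validation
    test_ratio=0.01,              # hold out 1% of rows for test
)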
config/train_gpt_oss_custom.py CHANGED
@@ -83,6 +83,9 @@ class GPTOSSEnhancedCustomConfig:
     eval_steps: int = 100  # Evaluate every N steps
     eval_delay: float = 0  # Delay evaluation for N steps/epochs
     eval_accumulation_steps: Optional[int] = None  # Accumulate eval outputs
+    # Automatic split ratios when only a single training split is provided
+    eval_ratio: float = 0.01  # Fraction of data for validation (0.0-0.5 typical)
+    test_ratio: float = 0.01  # Fraction of data for test (0.0-0.5 typical)
 
     # Checkpointing
     save_strategy: str = "steps"  # "no", "steps", "epoch"
@@ -167,6 +170,11 @@ class GPTOSSEnhancedCustomConfig:
 
     # Generation Configuration (for evaluation/testing)
     generation_config: Optional[Dict] = None
+
+    # Preference-training (DPO) configuration
+    chosen_field: Optional[str] = None  # Field name for preferred response (for DPO datasets)
+    rejected_field: Optional[str] = None  # Field name for rejected response (for DPO datasets)
+    dpo_beta: float = 0.1  # DPO beta parameter
 
     # ============================================================================
     # MULTILINGUAL & DOMAIN SPECIFIC SETTINGS
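As a rough worked example of the split ratio semantics (the 100,000-row dataset size here is invented; actual sizes come from whatever single train split is loaded), the defaults would hold out roughly:

# Hypothetical sizing example for the default split ratios; the row count is made up.
n_rows = 100_000
eval_ratio, test_ratio = 0.01, 0.01

test_size = round(n_rows * test_ratio)        # ~1,000 rows reserved for test
eval_size = round(n_rows * eval_ratio)        # ~1,000 rows reserved for validation
train_size = n_rows - test_size - eval_size   # ~98,000 rows left for training

print(f"train={train_size}, eval={eval_size}, test={test_size}")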
config/train_gpt_oss_h100_optimized.py CHANGED
@@ -62,6 +62,9 @@ class GPTOSSH100OptimizedConfig:
     metric_for_best_model: str = "eval_loss"
     greater_is_better: bool = False
     load_best_model_at_end: bool = True
+    eval_accumulation_steps: Optional[int] = None
+    eval_ratio: float = 0.01
+    test_ratio: float = 0.01
 
     # Data configuration
     dataset_name: str = "HuggingFaceH4/Multilingual-Thinking"
@@ -104,6 +107,10 @@ class GPTOSSH100OptimizedConfig:
     dataloader_num_workers: int = 8  # More workers for H100
     dataloader_pin_memory: bool = True
     dataloader_prefetch_factor: int = 4  # Increased prefetch
+    tf32: Optional[bool] = None
+    chosen_field: Optional[str] = None
+    rejected_field: Optional[str] = None
+    dpo_beta: float = 0.1
 
     # Memory optimizations for H100
     max_grad_norm: float = 1.0
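The new tf32 field is presumably surfaced to PyTorch when training arguments are built; that wiring is not part of this hunk, so the following is only a sketch of how such a flag is commonly applied:

import torch

# Hypothetical wiring for the new `tf32` config field (not shown in this diff):
# enable TensorFloat-32 matmul/cuDNN kernels on Ampere/Hopper GPUs when the flag is truthy.
def apply_tf32(config) -> None:
    tf32 = getattr(config, "tf32", None)
    if tf32 is not None:
        torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
        torch.backends.cudnn.allow_tf32 = bool(tf32)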
config/train_gpt_oss_memory_optimized.py CHANGED
@@ -43,6 +43,9 @@ class GPTOSSMemoryOptimizedConfig:
     metric_for_best_model: str = "eval_loss"
     greater_is_better: bool = False
     load_best_model_at_end: bool = True
+    eval_accumulation_steps: Optional[int] = None
+    eval_ratio: float = 0.01
+    test_ratio: float = 0.01
     dataset_name: str = "HuggingFaceH4/Multilingual-Thinking"
     dataset_split: str = "train"
     input_field: str = "messages"
@@ -65,6 +68,11 @@ class GPTOSSMemoryOptimizedConfig:
     use_quantization: bool = True
     quantization_config: dict = None
     model_kwargs: dict = None
+    dataloader_prefetch_factor: int = 2
+    tf32: Optional[bool] = None
+    chosen_field: Optional[str] = None
+    rejected_field: Optional[str] = None
+    dpo_beta: float = 0.1
     generation_config: dict = None
     reasoning_languages: list = None
 
config/train_gpt_oss_multilingual_reasoning.py CHANGED
@@ -62,6 +62,9 @@ class GPTOSSMultilingualReasoningConfig:
     metric_for_best_model: str = "eval_loss"
     greater_is_better: bool = False
     load_best_model_at_end: bool = True
+    eval_accumulation_steps: Optional[int] = None
+    eval_ratio: float = 0.01
+    test_ratio: float = 0.01
 
     # Data configuration - Multilingual-Thinking specific
    dataset_name: str = "HuggingFaceH4/Multilingual-Thinking"
@@ -99,6 +102,11 @@ class GPTOSSMultilingualReasoningConfig:
 
     # GPT-OSS specific model kwargs - as per tutorial
     model_kwargs: dict = None
+    dataloader_prefetch_factor: int = 2
+    tf32: Optional[bool] = None
+    chosen_field: Optional[str] = None
+    rejected_field: Optional[str] = None
+    dpo_beta: float = 0.1
 
     # Multilingual reasoning specific configurations
     # Generation parameters for multilingual reasoning
config/train_gpt_oss_openhermes_fr.py CHANGED
@@ -119,6 +119,9 @@ config = GPTOSSEnhancedCustomConfig(
     metric_for_best_model="eval_loss",
     greater_is_better=False,
     load_best_model_at_end=True,
+    # Split ratios for automatic validation/test creation
+    eval_ratio=0.01,
+    test_ratio=0.01,
 
     # ============================================================================
     # MULTILINGUAL & FRENCH SPECIFIC SETTINGS
config/train_gpt_oss_openhermes_fr_memory_optimized.py CHANGED
@@ -144,6 +144,9 @@ config = GPTOSSEnhancedCustomConfig(
     # Evaluation memory optimization
     eval_accumulation_steps=4,  # Accumulate eval outputs to save memory
     eval_batch_size=1,  # Smaller eval batch size
+    # Split ratios for automatic validation/test creation
+    eval_ratio=0.001,
+    test_ratio=0.0005,
 
     # ============================================================================
     # GPT-OSS HARMONY FORMAT OPTIMIZATION
launch.sh CHANGED
@@ -827,7 +827,15 @@ fi
 print_step "Step 3: Experiment Details"
 echo "=============================="
 
-get_input "Experiment name" "smollm3_finetune_$(date +%Y%m%d_%H%M%S)" EXPERIMENT_NAME
+# Derive default experiment name from smolfactory + chosen model family
+if [ "$MODEL_FAMILY" = "GPT-OSS" ]; then
+    FAMILY_SLUG="gpt-oss"
+else
+    FAMILY_SLUG="smollm3"
+fi
+DEFAULT_EXPERIMENT_NAME="smolfactory-${FAMILY_SLUG}_$(date +%Y%m%d_%H%M%S)"
+
+get_input "Experiment name" "$DEFAULT_EXPERIMENT_NAME" EXPERIMENT_NAME
 
 # Configure model repository name (customizable)
 print_info "Setting up model repository name..."
scripts/training/train_gpt_oss.py CHANGED
@@ -13,6 +13,10 @@ import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
 from peft import LoraConfig, get_peft_model
 from trl import SFTTrainer
+try:
+    from trl import DPOTrainer
+except Exception:  # pragma: no cover - optional import depending on TRL version
+    DPOTrainer = None
 from datasets import load_dataset
 from pathlib import Path
 
@@ -214,6 +218,10 @@ def format_gpt_oss_harmony(prompt, completion, add_eos_token=True):
 
     return harmony_text
 
+def format_gpt_oss_harmony_prompt(prompt: str) -> str:
+    """Prefix-only Harmony prompt up to assistant content marker for DPO."""
+    return f"<|start|>user<|message|>{prompt}<|end|><|start|>assistant<|channel|>final<|message|>"
+
 def process_dataset_format(dataset, config):
     """Process dataset based on format configuration with exact GPT-OSS Harmony compliance"""
 
@@ -224,11 +232,53 @@ def process_dataset_format(dataset, config):
     field_separator = getattr(config, 'field_separator', '\n\n### Response:\n')
     add_eos_token = getattr(config, 'add_eos_token', True)
     use_harmony_format = getattr(config, 'use_harmony_format', True)
+    trainer_type = getattr(config, 'trainer_type', 'sft')
 
     print(f"Processing dataset format: {dataset_format}")
     print(f"Input field: {input_field}, Target field: {target_field}")
     print(f"GPT-OSS Harmony Format: {'Enabled' if use_harmony_format else 'Disabled'}")
 
+    # Preference-format for DPO training (chosen/rejected pairs)
+    if trainer_type == 'dpo':
+        chosen_field = getattr(config, 'chosen_field', None)
+        rejected_field = getattr(config, 'rejected_field', None)
+
+        if dataset_format == 'preference':
+            # Expect columns present; optionally reformat to ensure only necessary columns
+            def id_map(example):
+                prompt_val = example.get(input_field, '')
+                chosen_val = example.get('chosen', example.get(chosen_field or 'chosen', ''))
+                rejected_val = example.get('rejected', example.get(rejected_field or 'rejected', ''))
+                if use_harmony_format:
+                    prompt_text = format_gpt_oss_harmony_prompt(prompt_val)
+                    chosen_text = (chosen_val or '') + ("<|return|>" if add_eos_token else '')
+                    rejected_text = (rejected_val or '') + ("<|return|>" if add_eos_token else '')
+                    return {"prompt": prompt_text, "chosen": chosen_text, "rejected": rejected_text}
+                return {"prompt": prompt_val, "chosen": chosen_val, "rejected": rejected_val}
+
+            keep_cols = [c for c in ['prompt', 'chosen', 'rejected'] if c in dataset.column_names]
+            dataset = dataset.map(id_map, remove_columns=dataset.column_names if keep_cols else dataset.column_names)
+            return dataset
+
+        # Custom preference mapping via configured field names
+        if chosen_field and rejected_field:
+            def to_pref(example):
+                prompt_val = example.get(input_field, '')
+                chosen_val = example.get(chosen_field, '')
+                rejected_val = example.get(rejected_field, '')
+                if use_harmony_format:
+                    prompt_text = format_gpt_oss_harmony_prompt(prompt_val)
+                    chosen_text = (chosen_val or '') + ("<|return|>" if add_eos_token else '')
+                    rejected_text = (rejected_val or '') + ("<|return|>" if add_eos_token else '')
+                    return {"prompt": prompt_text, "chosen": chosen_text, "rejected": rejected_text}
+                return {"prompt": prompt_val, "chosen": chosen_val, "rejected": rejected_val}
+
+            dataset = dataset.map(to_pref, remove_columns=dataset.column_names)
+            return dataset
+
+        # If we reach here, we don't have required fields for DPO
+        raise ValueError("DPO training requires preference data. Please set dataset_format='preference' with 'prompt', 'chosen', 'rejected' columns, or specify 'chosen_field' and 'rejected_field' in the config.")
+
     if dataset_format == "openhermes_fr":
         # Process OpenHermes-FR format: prompt + accepted_completion
         def format_openhermes_fr(example):
@@ -317,6 +367,72 @@ def process_dataset_format(dataset, config):
 
     return dataset
 
+def split_dataset(dataset, config):
+    """Create train/validation/test splits from a single dataset.
+    Defaults to 1% eval and 1% test if not specified.
+    """
+    from datasets import Dataset
+
+    if not isinstance(dataset, Dataset):
+        # If it's already a DatasetDict, try to use its splits
+        try:
+            train_split = dataset["train"]
+            eval_split = dataset.get("validation") or dataset.get("eval")
+            test_split = dataset.get("test")
+            return train_split, eval_split, test_split
+        except Exception:
+            pass
+
+    eval_ratio = getattr(config, 'eval_ratio', 0.01)
+    test_ratio = getattr(config, 'test_ratio', 0.01)
+
+    # Clamp ratios to sane bounds
+    try:
+        eval_ratio = max(0.0, float(eval_ratio))
+        test_ratio = max(0.0, float(test_ratio))
+        if eval_ratio + test_ratio >= 0.9:
+            # Avoid extreme splits; cap combined at 0.2
+            scale = 0.2 / max(1e-9, (eval_ratio + test_ratio))
+            eval_ratio *= scale
+            test_ratio *= scale
+    except Exception:
+        eval_ratio, test_ratio = 0.01, 0.01
+
+    # No eval/test requested
+    if eval_ratio <= 0 and test_ratio <= 0:
+        return dataset, None, None
+
+    ds_shuffled = dataset.shuffle(seed=42)
+
+    # First carve out test split
+    if test_ratio > 0:
+        split1 = ds_shuffled.train_test_split(test_size=test_ratio, seed=42)
+        train_part = split1["train"]
+        test_split = split1["test"]
+    else:
+        train_part = ds_shuffled
+        test_split = None
+
+    # Then carve out eval from remaining train
+    if eval_ratio > 0:
+        remaining_fraction = 1.0 - test_ratio
+        # Convert global eval fraction to fraction of remaining pool
+        relative_eval = eval_ratio / remaining_fraction if remaining_fraction > 0 else eval_ratio
+        split2 = train_part.train_test_split(test_size=relative_eval, seed=42)
+        train_split = split2["train"]
+        eval_split = split2["test"]
+    else:
+        train_split = train_part
+        eval_split = None
+
+    # Log sizes
+    try:
+        print(f"Created splits -> train: {len(train_split)}, eval: {len(eval_split) if eval_split else 0}, test: {len(test_split) if test_split else 0}")
+    except Exception:
+        pass
+
+    return train_split, eval_split, test_split
+
 def setup_trackio_tracking(config):
     """Setup Trackio tracking if enabled"""
 
@@ -530,6 +646,9 @@ def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer
 
     # Load dataset
     dataset = load_dataset_from_config(config)
+
+    # Split into train/eval/test
+    train_dataset, eval_dataset, test_dataset = split_dataset(dataset, config)
 
     # Setup Trackio tracking
     trackio_client = setup_trackio_tracking(config)
@@ -538,37 +657,78 @@ def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer
     sft_config = create_sft_config(config, output_dir)
 
     # Create trainer with version-robust kwargs
-    print("Creating SFT trainer...")
-    try:
-        sft_sig = inspect.signature(SFTTrainer.__init__)
-        sft_params = set(sft_sig.parameters.keys())
-    except Exception:
-        sft_params = {"model", "args", "train_dataset", "tokenizer", "dataset_text_field", "max_seq_length"}
+    if trainer_type == 'dpo':
+        if DPOTrainer is None:
+            raise RuntimeError("DPOTrainer is not available in this TRL version. Please upgrade 'trl'.")
 
-    sft_kwargs = {
-        "model": peft_model,
-        "args": sft_config,
-        "train_dataset": dataset,
-    }
+        print("Creating DPO trainer...")
+        try:
+            dpo_sig = inspect.signature(DPOTrainer.__init__)
+            dpo_params = set(dpo_sig.parameters.keys())
+        except Exception:
+            dpo_params = {"model", "args", "train_dataset", "tokenizer", "beta", "prompt_column", "chosen_column", "rejected_column"}
+
+        dpo_kwargs = {
+            "model": peft_model,
+            "args": sft_config,
+            "train_dataset": train_dataset,
+            "beta": getattr(config, 'dpo_beta', 0.1),
+        }
+
+        if "tokenizer" in dpo_params:
+            dpo_kwargs["tokenizer"] = tokenizer
+        elif "processing_class" in dpo_params:
+            dpo_kwargs["processing_class"] = tokenizer
+
+        if "prompt_column" in dpo_params:
+            dpo_kwargs["prompt_column"] = "prompt"
+        if "chosen_column" in dpo_params:
+            dpo_kwargs["chosen_column"] = "chosen"
+        if "rejected_column" in dpo_params:
+            dpo_kwargs["rejected_column"] = "rejected"
+
+        # Remove Nones
+        dpo_kwargs = {k: v for k, v in dpo_kwargs.items() if v is not None}
+
+        # Pass eval dataset if supported
+        if "eval_dataset" in dpo_params and eval_dataset is not None:
+            dpo_kwargs["eval_dataset"] = eval_dataset
+        trainer = DPOTrainer(**dpo_kwargs)
+    else:
+        print("Creating SFT trainer...")
+        try:
+            sft_sig = inspect.signature(SFTTrainer.__init__)
+            sft_params = set(sft_sig.parameters.keys())
+        except Exception:
+            sft_params = {"model", "args", "train_dataset", "tokenizer", "dataset_text_field", "max_seq_length"}
+
+        sft_kwargs = {
+            "model": peft_model,
+            "args": sft_config,
+            "train_dataset": train_dataset,
+        }
 
-    # Prefer passing tokenizer if supported; otherwise try processing_class
-    if "tokenizer" in sft_params:
-        sft_kwargs["tokenizer"] = tokenizer
-    elif "processing_class" in sft_params:
-        sft_kwargs["processing_class"] = tokenizer
+        # Prefer passing tokenizer if supported; otherwise try processing_class
+        if "tokenizer" in sft_params:
+            sft_kwargs["tokenizer"] = tokenizer
+        elif "processing_class" in sft_params:
+            sft_kwargs["processing_class"] = tokenizer
 
-    # Pass dataset text field if supported (we produced a 'text' column)
-    if "dataset_text_field" in sft_params:
-        sft_kwargs["dataset_text_field"] = "text"
+        # Pass dataset text field if supported (we produced a 'text' column)
+        if "dataset_text_field" in sft_params:
+            sft_kwargs["dataset_text_field"] = "text"
 
-    # Pass max sequence length if supported
-    if "max_seq_length" in sft_params:
-        sft_kwargs["max_seq_length"] = getattr(config, 'max_seq_length', 2048)
+        # Pass max sequence length if supported
+        if "max_seq_length" in sft_params:
+            sft_kwargs["max_seq_length"] = getattr(config, 'max_seq_length', 2048)
 
-    # Remove any None values
-    sft_kwargs = {k: v for k, v in sft_kwargs.items() if v is not None}
+        # Remove any None values
+        sft_kwargs = {k: v for k, v in sft_kwargs.items() if v is not None}
 
-    trainer = SFTTrainer(**sft_kwargs)
+        # Attach eval_dataset if supported
+        if "eval_dataset" in sft_params and eval_dataset is not None:
+            sft_kwargs["eval_dataset"] = eval_dataset
+        trainer = SFTTrainer(**sft_kwargs)
 
     # Start training
     print("Starting GPT-OSS training...")