adds better launch.sh experiment naming and automatic eval/test splits
- config/train_gpt_oss_basic.py +10 -0
- config/train_gpt_oss_custom.py +8 -0
- config/train_gpt_oss_h100_optimized.py +7 -0
- config/train_gpt_oss_memory_optimized.py +8 -0
- config/train_gpt_oss_multilingual_reasoning.py +8 -0
- config/train_gpt_oss_openhermes_fr.py +3 -0
- config/train_gpt_oss_openhermes_fr_memory_optimized.py +3 -0
- launch.sh +9 -1
- scripts/training/train_gpt_oss.py +185 -25
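
At a glance: every GPT-OSS config gains eval_ratio / test_ratio knobs (plus dataloader, tf32 and DPO extras), and the training script now derives validation and test splits automatically when the dataset ships only a single train split. A minimal sketch of how the new fields are meant to be set; the field names come from the diffs below, while the import path and inline construction are illustrative assumptions:

    # Illustrative sketch only - field names come from this commit, the import path is assumed.
    from config.train_gpt_oss_custom import GPTOSSEnhancedCustomConfig

    config = GPTOSSEnhancedCustomConfig(
        eval_ratio=0.01,      # ~1% of the single "train" split becomes validation
        test_ratio=0.01,      # ~1% becomes a held-out test split
        chosen_field=None,    # DPO-only: dataset column holding the preferred response
        rejected_field=None,  # DPO-only: dataset column holding the rejected response
        dpo_beta=0.1,         # DPO-only: beta parameter for the preference loss
    )
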
config/train_gpt_oss_basic.py
CHANGED
@@ -62,6 +62,9 @@ class GPTOSSBasicConfig:
     metric_for_best_model: str = "eval_loss"
     greater_is_better: bool = False
     load_best_model_at_end: bool = True
+    eval_accumulation_steps: Optional[int] = None
+    eval_ratio: float = 0.01
+    test_ratio: float = 0.01
 
     # Data configuration
     dataset_name: str = "HuggingFaceH4/Multilingual-Thinking"
@@ -99,6 +102,13 @@ class GPTOSSBasicConfig:
 
     # GPT-OSS specific model kwargs
     model_kwargs: dict = None
+    # Performance and precision extras
+    dataloader_prefetch_factor: int = 2
+    tf32: Optional[bool] = None
+    # DPO preference training fields
+    chosen_field: Optional[str] = None
+    rejected_field: Optional[str] = None
+    dpo_beta: float = 0.1
 
     def __post_init__(self):
         if self.chat_template_kwargs is None:

config/train_gpt_oss_custom.py
CHANGED
@@ -83,6 +83,9 @@ class GPTOSSEnhancedCustomConfig:
     eval_steps: int = 100  # Evaluate every N steps
     eval_delay: float = 0  # Delay evaluation for N steps/epochs
     eval_accumulation_steps: Optional[int] = None  # Accumulate eval outputs
+    # Automatic split ratios when only a single training split is provided
+    eval_ratio: float = 0.01  # Fraction of data for validation (0.0-0.5 typical)
+    test_ratio: float = 0.01  # Fraction of data for test (0.0-0.5 typical)
 
     # Checkpointing
     save_strategy: str = "steps"  # "no", "steps", "epoch"
@@ -167,6 +170,11 @@ class GPTOSSEnhancedCustomConfig:
 
     # Generation Configuration (for evaluation/testing)
     generation_config: Optional[Dict] = None
+
+    # Preference-training (DPO) configuration
+    chosen_field: Optional[str] = None  # Field name for preferred response (for DPO datasets)
+    rejected_field: Optional[str] = None  # Field name for rejected response (for DPO datasets)
+    dpo_beta: float = 0.1  # DPO beta parameter
 
     # ============================================================================
     # MULTILINGUAL & DOMAIN SPECIFIC SETTINGS

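For preference (DPO) datasets, the two new field names tell the script which columns hold the preferred and rejected responses; process_dataset_format() in scripts/training/train_gpt_oss.py (further down) maps them onto the prompt/chosen/rejected columns DPOTrainer expects. A hedged sketch with invented column names:

    # Hypothetical dataset columns: "prompt", "good_answer", "bad_answer"
    dpo_config = GPTOSSEnhancedCustomConfig(
        input_field="prompt",
        chosen_field="good_answer",   # mapped to the "chosen" column
        rejected_field="bad_answer",  # mapped to the "rejected" column
        dpo_beta=0.1,
    )
    # The DPO path is selected by trainer_type == 'dpo'; whether that flag lives on the
    # config or is passed to train_gpt_oss() separately is not shown in this diff.
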
config/train_gpt_oss_h100_optimized.py
CHANGED
@@ -62,6 +62,9 @@ class GPTOSSH100OptimizedConfig:
     metric_for_best_model: str = "eval_loss"
     greater_is_better: bool = False
     load_best_model_at_end: bool = True
+    eval_accumulation_steps: Optional[int] = None
+    eval_ratio: float = 0.01
+    test_ratio: float = 0.01
 
     # Data configuration
     dataset_name: str = "HuggingFaceH4/Multilingual-Thinking"
@@ -104,6 +107,10 @@ class GPTOSSH100OptimizedConfig:
     dataloader_num_workers: int = 8  # More workers for H100
     dataloader_pin_memory: bool = True
     dataloader_prefetch_factor: int = 4  # Increased prefetch
+    tf32: Optional[bool] = None
+    chosen_field: Optional[str] = None
+    rejected_field: Optional[str] = None
+    dpo_beta: float = 0.1
 
     # Memory optimizations for H100
     max_grad_norm: float = 1.0

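The new tf32 field presumably maps onto the tf32 flag of transformers TrainingArguments (enabling TensorFloat-32 matmuls on Ampere/H100 GPUs); leaving it at None keeps the library default. How it is wired into the training arguments is not shown in this diff.
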
config/train_gpt_oss_memory_optimized.py
CHANGED
@@ -43,6 +43,9 @@ class GPTOSSMemoryOptimizedConfig:
     metric_for_best_model: str = "eval_loss"
     greater_is_better: bool = False
     load_best_model_at_end: bool = True
+    eval_accumulation_steps: Optional[int] = None
+    eval_ratio: float = 0.01
+    test_ratio: float = 0.01
     dataset_name: str = "HuggingFaceH4/Multilingual-Thinking"
     dataset_split: str = "train"
     input_field: str = "messages"
@@ -65,6 +68,11 @@ class GPTOSSMemoryOptimizedConfig:
     use_quantization: bool = True
     quantization_config: dict = None
     model_kwargs: dict = None
+    dataloader_prefetch_factor: int = 2
+    tf32: Optional[bool] = None
+    chosen_field: Optional[str] = None
+    rejected_field: Optional[str] = None
+    dpo_beta: float = 0.1
     generation_config: dict = None
     reasoning_languages: list = None
 

config/train_gpt_oss_multilingual_reasoning.py
CHANGED
@@ -62,6 +62,9 @@ class GPTOSSMultilingualReasoningConfig:
     metric_for_best_model: str = "eval_loss"
     greater_is_better: bool = False
     load_best_model_at_end: bool = True
+    eval_accumulation_steps: Optional[int] = None
+    eval_ratio: float = 0.01
+    test_ratio: float = 0.01
 
     # Data configuration - Multilingual-Thinking specific
     dataset_name: str = "HuggingFaceH4/Multilingual-Thinking"
@@ -99,6 +102,11 @@ class GPTOSSMultilingualReasoningConfig:
 
     # GPT-OSS specific model kwargs - as per tutorial
     model_kwargs: dict = None
+    dataloader_prefetch_factor: int = 2
+    tf32: Optional[bool] = None
+    chosen_field: Optional[str] = None
+    rejected_field: Optional[str] = None
+    dpo_beta: float = 0.1
 
     # Multilingual reasoning specific configurations
     # Generation parameters for multilingual reasoning

config/train_gpt_oss_openhermes_fr.py
CHANGED
@@ -119,6 +119,9 @@ config = GPTOSSEnhancedCustomConfig(
     metric_for_best_model="eval_loss",
     greater_is_better=False,
     load_best_model_at_end=True,
+    # Split ratios for automatic validation/test creation
+    eval_ratio=0.01,
+    test_ratio=0.01,
 
     # ============================================================================
     # MULTILINGUAL & FRENCH SPECIFIC SETTINGS

config/train_gpt_oss_openhermes_fr_memory_optimized.py
CHANGED
@@ -144,6 +144,9 @@ config = GPTOSSEnhancedCustomConfig(
     # Evaluation memory optimization
     eval_accumulation_steps=4,  # Accumulate eval outputs to save memory
     eval_batch_size=1,  # Smaller eval batch size
+    # Split ratios for automatic validation/test creation
+    eval_ratio=0.001,
+    test_ratio=0.0005,
 
     # ============================================================================
     # GPT-OSS HARMONY FORMAT OPTIMIZATION

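The memory-optimized OpenHermes-FR run uses much smaller ratios than the 1% defaults: as rough arithmetic, on a train split of 1,000,000 rows (an illustrative figure, not the actual dataset size) eval_ratio=0.001 keeps about 1,000 rows for validation and test_ratio=0.0005 about 500 for test, so evaluation stays cheap on a large corpus.
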
launch.sh
CHANGED
@@ -827,7 +827,15 @@ fi
 print_step "Step 3: Experiment Details"
 echo "=============================="
 
-
+# Derive default experiment name from smolfactory + chosen model family
+if [ "$MODEL_FAMILY" = "GPT-OSS" ]; then
+    FAMILY_SLUG="gpt-oss"
+else
+    FAMILY_SLUG="smollm3"
+fi
+DEFAULT_EXPERIMENT_NAME="smolfactory-${FAMILY_SLUG}_$(date +%Y%m%d_%H%M%S)"
+
+get_input "Experiment name" "$DEFAULT_EXPERIMENT_NAME" EXPERIMENT_NAME
 
 # Configure model repository name (customizable)
 print_info "Setting up model repository name..."

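With this change the experiment-name prompt is pre-filled with a value like smolfactory-gpt-oss_20250805_143000 for GPT-OSS runs (or smolfactory-smollm3_... otherwise); the timestamp shown is illustrative, and get_input still lets the user type a different name interactively.
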
scripts/training/train_gpt_oss.py
CHANGED
@@ -13,6 +13,10 @@ import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
 from peft import LoraConfig, get_peft_model
 from trl import SFTTrainer
+try:
+    from trl import DPOTrainer
+except Exception:  # pragma: no cover - optional import depending on TRL version
+    DPOTrainer = None
 from datasets import load_dataset
 from pathlib import Path
 
@@ -214,6 +218,10 @@ def format_gpt_oss_harmony(prompt, completion, add_eos_token=True):
 
     return harmony_text
 
+def format_gpt_oss_harmony_prompt(prompt: str) -> str:
+    """Prefix-only Harmony prompt up to assistant content marker for DPO."""
+    return f"<|start|>user<|message|>{prompt}<|end|><|start|>assistant<|channel|>final<|message|>"
+
 def process_dataset_format(dataset, config):
     """Process dataset based on format configuration with exact GPT-OSS Harmony compliance"""
 
@@ -224,11 +232,53 @@ def process_dataset_format(dataset, config):
     field_separator = getattr(config, 'field_separator', '\n\n### Response:\n')
     add_eos_token = getattr(config, 'add_eos_token', True)
     use_harmony_format = getattr(config, 'use_harmony_format', True)
+    trainer_type = getattr(config, 'trainer_type', 'sft')
 
     print(f"Processing dataset format: {dataset_format}")
     print(f"Input field: {input_field}, Target field: {target_field}")
     print(f"GPT-OSS Harmony Format: {'Enabled' if use_harmony_format else 'Disabled'}")
 
+    # Preference-format for DPO training (chosen/rejected pairs)
+    if trainer_type == 'dpo':
+        chosen_field = getattr(config, 'chosen_field', None)
+        rejected_field = getattr(config, 'rejected_field', None)
+
+        if dataset_format == 'preference':
+            # Expect columns present; optionally reformat to ensure only necessary columns
+            def id_map(example):
+                prompt_val = example.get(input_field, '')
+                chosen_val = example.get('chosen', example.get(chosen_field or 'chosen', ''))
+                rejected_val = example.get('rejected', example.get(rejected_field or 'rejected', ''))
+                if use_harmony_format:
+                    prompt_text = format_gpt_oss_harmony_prompt(prompt_val)
+                    chosen_text = (chosen_val or '') + ("<|return|>" if add_eos_token else '')
+                    rejected_text = (rejected_val or '') + ("<|return|>" if add_eos_token else '')
+                    return {"prompt": prompt_text, "chosen": chosen_text, "rejected": rejected_text}
+                return {"prompt": prompt_val, "chosen": chosen_val, "rejected": rejected_val}
+
+            keep_cols = [c for c in ['prompt', 'chosen', 'rejected'] if c in dataset.column_names]
+            dataset = dataset.map(id_map, remove_columns=dataset.column_names if keep_cols else dataset.column_names)
+            return dataset
+
+        # Custom preference mapping via configured field names
+        if chosen_field and rejected_field:
+            def to_pref(example):
+                prompt_val = example.get(input_field, '')
+                chosen_val = example.get(chosen_field, '')
+                rejected_val = example.get(rejected_field, '')
+                if use_harmony_format:
+                    prompt_text = format_gpt_oss_harmony_prompt(prompt_val)
+                    chosen_text = (chosen_val or '') + ("<|return|>" if add_eos_token else '')
+                    rejected_text = (rejected_val or '') + ("<|return|>" if add_eos_token else '')
+                    return {"prompt": prompt_text, "chosen": chosen_text, "rejected": rejected_text}
+                return {"prompt": prompt_val, "chosen": chosen_val, "rejected": rejected_val}
+
+            dataset = dataset.map(to_pref, remove_columns=dataset.column_names)
+            return dataset
+
+        # If we reach here, we don't have required fields for DPO
+        raise ValueError("DPO training requires preference data. Please set dataset_format='preference' with 'prompt', 'chosen', 'rejected' columns, or specify 'chosen_field' and 'rejected_field' in the config.")
+
     if dataset_format == "openhermes_fr":
         # Process OpenHermes-FR format: prompt + accepted_completion
         def format_openhermes_fr(example):
@@ -317,6 +367,72 @@
 
     return dataset
 
+def split_dataset(dataset, config):
+    """Create train/validation/test splits from a single dataset.
+    Defaults to 1% eval and 1% test if not specified.
+    """
+    from datasets import Dataset
+
+    if not isinstance(dataset, Dataset):
+        # If it's already a DatasetDict, try to use its splits
+        try:
+            train_split = dataset["train"]
+            eval_split = dataset.get("validation") or dataset.get("eval")
+            test_split = dataset.get("test")
+            return train_split, eval_split, test_split
+        except Exception:
+            pass
+
+    eval_ratio = getattr(config, 'eval_ratio', 0.01)
+    test_ratio = getattr(config, 'test_ratio', 0.01)
+
+    # Clamp ratios to sane bounds
+    try:
+        eval_ratio = max(0.0, float(eval_ratio))
+        test_ratio = max(0.0, float(test_ratio))
+        if eval_ratio + test_ratio >= 0.9:
+            # Avoid extreme splits; cap combined at 0.2
+            scale = 0.2 / max(1e-9, (eval_ratio + test_ratio))
+            eval_ratio *= scale
+            test_ratio *= scale
+    except Exception:
+        eval_ratio, test_ratio = 0.01, 0.01
+
+    # No eval/test requested
+    if eval_ratio <= 0 and test_ratio <= 0:
+        return dataset, None, None
+
+    ds_shuffled = dataset.shuffle(seed=42)
+
+    # First carve out test split
+    if test_ratio > 0:
+        split1 = ds_shuffled.train_test_split(test_size=test_ratio, seed=42)
+        train_part = split1["train"]
+        test_split = split1["test"]
+    else:
+        train_part = ds_shuffled
+        test_split = None
+
+    # Then carve out eval from remaining train
+    if eval_ratio > 0:
+        remaining_fraction = 1.0 - test_ratio
+        # Convert global eval fraction to fraction of remaining pool
+        relative_eval = eval_ratio / remaining_fraction if remaining_fraction > 0 else eval_ratio
+        split2 = train_part.train_test_split(test_size=relative_eval, seed=42)
+        train_split = split2["train"]
+        eval_split = split2["test"]
+    else:
+        train_split = train_part
+        eval_split = None
+
+    # Log sizes
+    try:
+        print(f"Created splits -> train: {len(train_split)}, eval: {len(eval_split) if eval_split else 0}, test: {len(test_split) if test_split else 0}")
+    except Exception:
+        pass
+
+    return train_split, eval_split, test_split
+
 def setup_trackio_tracking(config):
     """Setup Trackio tracking if enabled"""
 
@@ -530,6 +646,9 @@ def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer
 
     # Load dataset
     dataset = load_dataset_from_config(config)
+
+    # Split into train/eval/test
+    train_dataset, eval_dataset, test_dataset = split_dataset(dataset, config)
 
     # Setup Trackio tracking
     trackio_client = setup_trackio_tracking(config)
@@ -538,37 +657,78 @@ def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer
     sft_config = create_sft_config(config, output_dir)
 
     # Create trainer with version-robust kwargs
-        sft_params = set(sft_sig.parameters.keys())
-    except Exception:
-        sft_params = {"model", "args", "train_dataset", "tokenizer", "dataset_text_field", "max_seq_length"}
+    if trainer_type == 'dpo':
+        if DPOTrainer is None:
+            raise RuntimeError("DPOTrainer is not available in this TRL version. Please upgrade 'trl'.")
+
+        print("Creating DPO trainer...")
+        try:
+            dpo_sig = inspect.signature(DPOTrainer.__init__)
+            dpo_params = set(dpo_sig.parameters.keys())
+        except Exception:
+            dpo_params = {"model", "args", "train_dataset", "tokenizer", "beta", "prompt_column", "chosen_column", "rejected_column"}
+
+        dpo_kwargs = {
+            "model": peft_model,
+            "args": sft_config,
+            "train_dataset": train_dataset,
+            "beta": getattr(config, 'dpo_beta', 0.1),
+        }
+
+        if "tokenizer" in dpo_params:
+            dpo_kwargs["tokenizer"] = tokenizer
+        elif "processing_class" in dpo_params:
+            dpo_kwargs["processing_class"] = tokenizer
+
+        if "prompt_column" in dpo_params:
+            dpo_kwargs["prompt_column"] = "prompt"
+        if "chosen_column" in dpo_params:
+            dpo_kwargs["chosen_column"] = "chosen"
+        if "rejected_column" in dpo_params:
+            dpo_kwargs["rejected_column"] = "rejected"
+
+        # Remove Nones
+        dpo_kwargs = {k: v for k, v in dpo_kwargs.items() if v is not None}
+
+        # Pass eval dataset if supported
+        if "eval_dataset" in dpo_params and eval_dataset is not None:
+            dpo_kwargs["eval_dataset"] = eval_dataset
+        trainer = DPOTrainer(**dpo_kwargs)
+    else:
+        print("Creating SFT trainer...")
+        try:
+            sft_sig = inspect.signature(SFTTrainer.__init__)
+            sft_params = set(sft_sig.parameters.keys())
+        except Exception:
+            sft_params = {"model", "args", "train_dataset", "tokenizer", "dataset_text_field", "max_seq_length"}
+
+        sft_kwargs = {
+            "model": peft_model,
+            "args": sft_config,
+            "train_dataset": train_dataset,
+        }
+
+        # Prefer passing tokenizer if supported; otherwise try processing_class
+        if "tokenizer" in sft_params:
+            sft_kwargs["tokenizer"] = tokenizer
+        elif "processing_class" in sft_params:
+            sft_kwargs["processing_class"] = tokenizer
+
+        # Pass dataset text field if supported (we produced a 'text' column)
+        if "dataset_text_field" in sft_params:
+            sft_kwargs["dataset_text_field"] = "text"
+
+        # Pass max sequence length if supported
+        if "max_seq_length" in sft_params:
+            sft_kwargs["max_seq_length"] = getattr(config, 'max_seq_length', 2048)
+
+        # Remove any None values
+        sft_kwargs = {k: v for k, v in sft_kwargs.items() if v is not None}
+
+        # Attach eval_dataset if supported
+        if "eval_dataset" in sft_params and eval_dataset is not None:
+            sft_kwargs["eval_dataset"] = eval_dataset
+        trainer = SFTTrainer(**sft_kwargs)
 
     # Start training
     print("Starting GPT-OSS training...")
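
The only non-obvious step in split_dataset() is converting the global eval fraction into a fraction of the pool that remains after the test split has been carved out. A quick arithmetic check of that conversion, with a made-up row count (not from any real run):

    # Standalone sanity check of the two-stage split arithmetic (illustrative numbers).
    n_rows = 100_000
    eval_ratio, test_ratio = 0.01, 0.01

    n_test = round(n_rows * test_ratio)              # 1,000 rows carved out first
    remaining = n_rows - n_test                      # 99,000 rows left for train+eval
    relative_eval = eval_ratio / (1.0 - test_ratio)  # 0.01 / 0.99 ~= 0.010101
    n_eval = round(remaining * relative_eval)        # ~1,000 rows, i.e. ~1% of the original
    n_train = remaining - n_eval                     # ~98,000 rows

    print(n_train, n_eval, n_test)                   # 98000 1000 1000

Because the eval fraction is rescaled against the remaining pool, eval_ratio and test_ratio both stay global fractions of the original dataset rather than compounding.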