adds defensive programming (boo) and adaptations based on transformers versions
scripts/training/train_gpt_oss.py CHANGED
@@ -8,6 +8,7 @@ Based on the GPT-OSS fine-tuning tutorial
 import os
 import sys
 import argparse
+import inspect
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
 from peft import LoraConfig, get_peft_model
@@ -386,62 +387,72 @@ def create_sft_config(config, output_dir):
     print(f" • Gradient accumulation: {gradient_accumulation_steps}")
     print(f" • Effective batch size: {per_device_train_batch_size * gradient_accumulation_steps}")
 
-    sft_config = TrainingArguments(
+    # Build kwargs dynamically to be compatible across transformers versions
+    ta_kwargs = {
         # Training duration
-        num_train_epochs=num_train_epochs,
-        max_steps=max_steps,
-
+        "num_train_epochs": num_train_epochs,
+        "max_steps": max_steps,
         # Learning rate
-        learning_rate=learning_rate,
-        lr_scheduler_type=lr_scheduler_type,
-        warmup_ratio=warmup_ratio,
-        warmup_steps=warmup_steps,
-
+        "learning_rate": learning_rate,
+        "lr_scheduler_type": lr_scheduler_type,
+        "warmup_ratio": warmup_ratio,
+        "warmup_steps": warmup_steps,
         # Batch configuration
-        per_device_train_batch_size=per_device_train_batch_size,
-        per_device_eval_batch_size=per_device_eval_batch_size,
-        gradient_accumulation_steps=gradient_accumulation_steps,
-
+        "per_device_train_batch_size": per_device_train_batch_size,
+        "per_device_eval_batch_size": per_device_eval_batch_size,
+        "gradient_accumulation_steps": gradient_accumulation_steps,
         # Model configuration
-        gradient_checkpointing=getattr(config, 'use_gradient_checkpointing', True),
-
+        "gradient_checkpointing": getattr(config, 'use_gradient_checkpointing', True),
         # Mixed precision
-        fp16=fp16,
-        bf16=bf16,
-
+        "fp16": fp16,
+        "bf16": bf16,
         # Regularization
-        weight_decay=weight_decay,
-        max_grad_norm=max_grad_norm,
-
-        # Evaluation
-        evaluation_strategy=eval_strategy,
-        eval_steps=eval_steps,
-
+        "weight_decay": weight_decay,
+        "max_grad_norm": max_grad_norm,
+        # Evaluation (name may vary across versions)
+        "evaluation_strategy": eval_strategy,
+        "eval_steps": eval_steps,
         # Logging
-        logging_steps=logging_steps,
-
+        "logging_steps": logging_steps,
         # Saving
-        save_strategy=save_strategy,
-        save_steps=save_steps,
-        save_total_limit=save_total_limit,
-
+        "save_strategy": save_strategy,
+        "save_steps": save_steps,
+        "save_total_limit": save_total_limit,
         # Output
-        output_dir=output_dir,
-
+        "output_dir": output_dir,
         # Data loading
-        dataloader_num_workers=getattr(config, 'dataloader_num_workers', 4),
-        dataloader_pin_memory=getattr(config, 'dataloader_pin_memory', True),
-
+        "dataloader_num_workers": getattr(config, 'dataloader_num_workers', 4),
+        "dataloader_pin_memory": getattr(config, 'dataloader_pin_memory', True),
         # Performance
-        group_by_length=getattr(config, 'group_by_length', True),
-        remove_unused_columns=getattr(config, 'remove_unused_columns', True),
-
+        "group_by_length": getattr(config, 'group_by_length', True),
+        "remove_unused_columns": getattr(config, 'remove_unused_columns', True),
         # HuggingFace Hub
-        push_to_hub=push_to_hub,
-
+        "push_to_hub": push_to_hub,
         # Monitoring
-        report_to="trackio" if getattr(config, 'enable_tracking', False) else None,
-    )
+        "report_to": ("trackio" if getattr(config, 'enable_tracking', False) else None),
+    }
+
+    # Adapt to transformers versions where 'evaluation_strategy' was renamed
+    try:
+        ta_sig = inspect.signature(TrainingArguments.__init__)
+        param_names = set(ta_sig.parameters.keys())
+    except Exception:
+        param_names = set()
+
+    if "evaluation_strategy" not in param_names and "eval_strategy" in param_names:
+        # Move value to 'eval_strategy'
+        ta_kwargs["eval_strategy"] = ta_kwargs.pop("evaluation_strategy")
+    elif "evaluation_strategy" not in param_names:
+        # If neither is supported, drop it
+        ta_kwargs.pop("evaluation_strategy", None)
+
+    # Remove any kwargs not supported by current transformers version
+    if param_names:
+        unsupported = [k for k in ta_kwargs.keys() if k not in param_names]
+        for k in unsupported:
+            ta_kwargs.pop(k, None)
+
+    sft_config = TrainingArguments(**ta_kwargs)
 
     return sft_config
 
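
The same defensive pattern generalizes beyond TrainingArguments: build the keyword arguments as a plain dict, then drop anything the installed library version does not accept before calling the constructor. A minimal self-contained sketch of that idea (the filter_supported_kwargs helper and the example kwargs are illustrative, not part of this commit):

import inspect

def filter_supported_kwargs(callable_obj, kwargs):
    """Return only the entries of kwargs that callable_obj's signature accepts.

    If the signature cannot be inspected, return the kwargs unchanged,
    mirroring the except-Exception fallback in the diff above.
    """
    try:
        accepted = set(inspect.signature(callable_obj).parameters)
    except (TypeError, ValueError):
        return dict(kwargs)
    return {k: v for k, v in kwargs.items() if k in accepted}

# Usage sketch (assumes transformers is installed):
# from transformers import TrainingArguments
# ta_kwargs = {"output_dir": "out", "evaluation_strategy": "steps", "eval_steps": 50}
# args = TrainingArguments(**filter_supported_kwargs(TrainingArguments.__init__, ta_kwargs))

Unlike the commit, this sketch silently drops a renamed argument instead of remapping it to its new name, which is why the script above handles the evaluation_strategy / eval_strategy rename explicitly before filtering.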