fix config bug
- config/train_smollm3_h100_lightweight.py +142 -92
- scripts/training/train.py +20 -3
- src/train.py +6 -2
- test_config.py +53 -0
config/train_smollm3_h100_lightweight.py
CHANGED
```diff
@@ -3,112 +3,162 @@ SmolLM3 H100 Lightweight Training Configuration
 Optimized for rapid training on H100 with 80K Hermes-FR samples
 """
 
+import os
+from dataclasses import dataclass
+from typing import Optional
 from config.train_smollm3 import SmolLM3Config
 
+@dataclass
+class SmolLM3ConfigH100Lightweight(SmolLM3Config):
+    """Configuration for SmolLM3 fine-tuning on OpenHermes-FR dataset - H100 Lightweight"""
+
+    # Model configuration - optimized for H100
+    model_name: str = "HuggingFaceTB/SmolLM3-3B"
+    max_seq_length: int = 8192  # Increased for better context understanding
+    use_flash_attention: bool = True
+    use_gradient_checkpointing: bool = True  # Enabled for memory efficiency
+
+    # Training configuration - H100 optimized for rapid training
+    batch_size: int = 16  # Larger batch size for H100
+    gradient_accumulation_steps: int = 4  # Reduced for faster updates
+    learning_rate: float = 8e-6  # Slightly higher for rapid convergence
+    weight_decay: float = 0.01
+    warmup_steps: int = 50  # Reduced warmup for rapid training
+    max_iters: Optional[int] = None  # Will be calculated based on epochs
+    eval_interval: int = 50  # More frequent evaluation
+    log_interval: int = 5  # More frequent logging
+    save_interval: int = 200  # More frequent saving
+
+    # Optimizer configuration - optimized for rapid training
+    optimizer: str = "adamw_torch"
+    beta1: float = 0.9
+    beta2: float = 0.95
+    eps: float = 1e-8
+
+    # Scheduler configuration - faster learning
+    scheduler: str = "cosine"
+    min_lr: float = 2e-6  # Higher minimum LR
 
     # Mixed precision - Full precision for H100
-    fp16=True
-    bf16=False
+    fp16: bool = True
+    bf16: bool = False
 
-    # Logging and saving -
-    save_steps=200
-    eval_steps=50
-    logging_steps=5
-    save_total_limit=2
+    # Logging and saving - more frequent for rapid training
+    save_steps: int = 200
+    eval_steps: int = 50
+    logging_steps: int = 5
+    save_total_limit: Optional[int] = 2  # Keep fewer checkpoints
 
     # Evaluation
-    eval_strategy="steps"
-    metric_for_best_model="eval_loss"
-    greater_is_better=False
-    load_best_model_at_end=True
-    #
-    dataset_name="legmlai/openhermes-fr"
-    dataset_split="train"
-    input_field="prompt"
-    target_field="completion"
-    filter_bad_entries=False
-    bad_entry_field="bad_entry"
-    sample_size=80000
-    sample_seed=42
+    eval_strategy: str = "steps"
+    metric_for_best_model: str = "eval_loss"
+    greater_is_better: bool = False
+    load_best_model_at_end: bool = True
+
+    # OpenHermes-FR Dataset configuration with sampling
+    dataset_name: str = "legmlai/openhermes-fr"
+    dataset_split: str = "train"
+    input_field: str = "prompt"
+    target_field: str = "completion"
+    filter_bad_entries: bool = False
+    bad_entry_field: str = "bad_entry"
+    sample_size: int = 80000  # 80K samples for lightweight training
+    sample_seed: int = 42  # For reproducibility
+
+    # Data configuration (not used for HF datasets but kept for compatibility)
+    data_dir: str = "my_dataset"
+    train_file: str = "train.json"
+    validation_file: Optional[str] = "validation.json"
+    test_file: Optional[str] = None
 
     # Chat template configuration
-    use_chat_template=True
-    chat_template_kwargs=
-    "enable_thinking": False,
-    "add_generation_prompt": True,
-    "no_think_system_message": True
-    },
+    use_chat_template: bool = True
+    chat_template_kwargs: Optional[dict] = None
 
     # Trackio monitoring configuration
-    enable_tracking=True
-    trackio_url
-    trackio_token=None
-    log_artifacts=True
-    log_metrics=True
-    log_config=True
-    experiment_name
+    enable_tracking: bool = True
+    trackio_url: Optional[str] = None
+    trackio_token: Optional[str] = None
+    log_artifacts: bool = True
+    log_metrics: bool = True
+    log_config: bool = True
+    experiment_name: Optional[str] = None
 
     # HF Datasets configuration
+    hf_token: Optional[str] = None
+    dataset_repo: Optional[str] = None
 
     # H100-specific optimizations
-    dataloader_num_workers=4
-    dataloader_pin_memory=True
+    dataloader_num_workers: int = 4  # Optimized for H100
+    dataloader_pin_memory: bool = True
+    dataloader_prefetch_factor: int = 2
 
     # Memory optimizations for rapid training
-    max_grad_norm=1.0
+    max_grad_norm: float = 1.0
+    group_by_length: bool = True  # Group similar length sequences
+
+    # Training duration calculations
+    # With 80k datapoints and effective batch size of 64:
+    # Steps per epoch = 80,000 / 64 = 1,250 steps
+    # For 1 epoch: 1,250 steps
+    # For 2 epochs: 2,500 steps
+
+    def __post_init__(self):
+        if self.chat_template_kwargs is None:
+            self.chat_template_kwargs = {
+                "enable_thinking": False,
+                "add_generation_prompt": True,
+                "no_think_system_message": True
+            }
+
+        # Validate configuration
+        if self.fp16 and self.bf16:
+            raise ValueError("Cannot use both fp16 and bf16")
+
+        if self.max_seq_length > 131072:  # 128k limit
+            raise ValueError("max_seq_length cannot exceed 131072")
+
+        # Calculate training statistics
+        effective_batch_size = self.batch_size * self.gradient_accumulation_steps
+        steps_per_epoch = self.sample_size // effective_batch_size  # For 80k dataset
+        epochs_for_max_iters = self.max_iters / steps_per_epoch if self.max_iters else 1
+
+        print("=== H100 Lightweight Training Configuration ===")
+        print(f"Effective batch size: {effective_batch_size}")
+        print(f"Steps per epoch: ~{steps_per_epoch}")
+        print(f"Training for ~{epochs_for_max_iters:.1f} epochs")
+        print(f"Total training steps: {self.max_iters or 'auto'}")
+        print(f"Learning rate: {self.learning_rate}")
+        print(f"Mixed precision: {'fp16' if self.fp16 else 'bf16'}")
+        print(f"Max sequence length: {self.max_seq_length}")
+        print(f"Gradient checkpointing: {self.use_gradient_checkpointing}")
+        print(f"Dataset sample size: {self.sample_size}")
+        print("=" * 50)
+
+        # Set default experiment name if not provided
+        if self.experiment_name is None:
+            self.experiment_name = "smollm3_h100_lightweight"
+
+def get_config(config_path: str) -> SmolLM3ConfigH100Lightweight:
+    """Load configuration from file or return default"""
+    if os.path.exists(config_path):
+        # Load from file if it exists
+        import importlib.util
+        spec = importlib.util.spec_from_file_location("config_module", config_path)
+        config_module = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(config_module)
+
+        if hasattr(config_module, 'config'):
+            return config_module.config
+        else:
+            # Try to find a config class
+            for attr_name in dir(config_module):
+                attr = getattr(config_module, attr_name)
+                if isinstance(attr, SmolLM3ConfigH100Lightweight):
+                    return attr
+
+    # Return default configuration
+    return SmolLM3ConfigH100Lightweight()
+
+# Default configuration instance
+config = SmolLM3ConfigH100Lightweight()
```
scripts/training/train.py
CHANGED
```diff
@@ -53,6 +53,12 @@ def main():
         type=str,
         help="Trackio token for authentication"
     )
+    parser.add_argument(
+        "--dataset-dir",
+        type=str,
+        default="my_dataset",
+        help="Dataset directory path"
+    )
 
     args = parser.parse_args()
 
@@ -65,13 +71,13 @@ def main():
     # Import all available configurations
     from config.train_smollm3_openhermes_fr_a100_large import get_config as get_large_config
     from config.train_smollm3_openhermes_fr_a100_multiple_passes import get_config as get_multiple_passes_config
-    from config.train_smollm3_h100_lightweight import
+    from config.train_smollm3_h100_lightweight import get_config as get_h100_lightweight_config
 
     # Map config files to their respective functions
     config_map = {
         "config/train_smollm3_openhermes_fr_a100_large.py": get_large_config,
         "config/train_smollm3_openhermes_fr_a100_multiple_passes.py": get_multiple_passes_config,
-        "config/train_smollm3_h100_lightweight.py":
+        "config/train_smollm3_h100_lightweight.py": get_h100_lightweight_config,
     }
 
     if args.config in config_map:
@@ -116,7 +122,15 @@ def main():
     print(f"Max iterations: {config.max_iters}")
     print(f"Max sequence length: {config.max_seq_length}")
     print(f"Mixed precision: {'bf16' if config.bf16 else 'fp16'}")
+    if hasattr(config, 'dataset_name') and config.dataset_name:
+        print(f"Dataset: {config.dataset_name}")
+        if hasattr(config, 'sample_size') and config.sample_size:
+            print(f"Sample size: {config.sample_size}")
+    else:
+        print(f"Dataset directory: {config.data_dir}")
+        print(f"Training file: {config.train_file}")
+        if config.validation_file:
+            print(f"Validation file: {config.validation_file}")
     if config.trackio_url:
         print(f"Trackio URL: {config.trackio_url}")
     if config.trackio_token:
@@ -151,6 +165,9 @@ def main():
     if args.experiment_name:
         train_args.extend(["--experiment_name", args.experiment_name])
 
+    # Add dataset directory argument
+    train_args.extend(["--dataset_dir", args.dataset_dir])
+
     # Override sys.argv for the training script
     original_argv = sys.argv
     sys.argv = ["train.py"] + train_args
```
src/train.py
CHANGED
```diff
@@ -174,13 +174,17 @@ def main():
     )
 
     # Determine dataset path
+    # Check if using Hugging Face dataset or local dataset
     if hasattr(config, 'dataset_name') and config.dataset_name:
         # Use Hugging Face dataset
         dataset_path = config.dataset_name
         logger.info(f"Using Hugging Face dataset: {dataset_path}")
     else:
-        # Use local dataset
+        # Use local dataset from config or command line argument
+        if args.dataset_dir:
+            dataset_path = os.path.join('/input', args.dataset_dir)
+        else:
+            dataset_path = os.path.join('/input', config.data_dir)
         logger.info(f"Using local dataset: {dataset_path}")
 
     # Load dataset with filtering options and sampling
```
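
The resolution order is: a Hugging Face dataset id on the config wins; otherwise the local directory is joined under `/input`, preferring the `--dataset_dir` argument over `config.data_dir`. An equivalent standalone sketch (helper name hypothetical):

```python
# Equivalent path resolution as a pure function (helper name hypothetical).
import os

def resolve_dataset_path(dataset_name, dataset_dir, data_dir):
    if dataset_name:  # Hugging Face dataset id takes priority
        return dataset_name
    return os.path.join('/input', dataset_dir or data_dir)

assert resolve_dataset_path("legmlai/openhermes-fr", None, "my_dataset") == "legmlai/openhermes-fr"
assert resolve_dataset_path(None, "my_dataset", "other") == "/input/my_dataset"
assert resolve_dataset_path(None, None, "my_dataset") == "/input/my_dataset"
```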
test_config.py
ADDED
```diff
@@ -0,0 +1,53 @@
+#!/usr/bin/env python3
+"""
+Test script to verify H100 lightweight configuration loads correctly
+"""
+
+import sys
+import os
+
+# Add project root to path
+project_root = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, project_root)
+
+def test_h100_lightweight_config():
+    """Test the H100 lightweight configuration"""
+    try:
+        from config.train_smollm3_h100_lightweight import config
+
+        print("✅ H100 Lightweight configuration loaded successfully!")
+        print(f"Model: {config.model_name}")
+        print(f"Dataset: {config.dataset_name}")
+        print(f"Sample size: {config.sample_size}")
+        print(f"Batch size: {config.batch_size}")
+        print(f"Learning rate: {config.learning_rate}")
+        print(f"Max sequence length: {config.max_seq_length}")
+
+        return True
+    except Exception as e:
+        print(f"❌ Error loading H100 lightweight configuration: {e}")
+        return False
+
+def test_training_script_import():
+    """Test that the training script can import the configuration"""
+    try:
+        from scripts.training.train import main
+        print("✅ Training script imports successfully!")
+        return True
+    except Exception as e:
+        print(f"❌ Error importing training script: {e}")
+        return False
+
+if __name__ == "__main__":
+    print("Testing H100 Lightweight Configuration...")
+    print("=" * 50)
+
+    success = True
+    success &= test_h100_lightweight_config()
+    success &= test_training_script_import()
+
+    if success:
+        print("\n🎉 All tests passed! Configuration is ready for training.")
+    else:
+        print("\n❌ Some tests failed. Please check the configuration.")
+        sys.exit(1)
```
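
To exercise the new check, run `python test_config.py` from the repository root; it prints the loaded settings and exits nonzero if either the configuration or the training script fails to import.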