Tonic committed
Commit 54ebacf · verified · 1 Parent(s): 967ff41

fix config bug

config/train_smollm3_h100_lightweight.py CHANGED
@@ -3,112 +3,162 @@ SmolLM3 H100 Lightweight Training Configuration
 Optimized for rapid training on H100 with 80K Hermes-FR samples
 """
 
+import os
+from dataclasses import dataclass
+from typing import Optional
 from config.train_smollm3 import SmolLM3Config
 
-config = SmolLM3Config(
-    # Model configuration
-    model_name="HuggingFaceTB/SmolLM3-3B",
-    max_seq_length=8192,
-    use_flash_attention=True,
-    use_gradient_checkpointing=True,
-
-    # Training configuration - Optimized for H100
-    batch_size=16,  # Larger batch size for H100
-    gradient_accumulation_steps=4,  # Reduced for faster updates
-    learning_rate=8e-6,  # Slightly higher for rapid convergence
-    weight_decay=0.01,
-    warmup_steps=50,  # Reduced warmup for rapid training
-    max_iters=None,  # Will be calculated based on epochs
-    eval_interval=50,  # More frequent evaluation
-    log_interval=5,  # More frequent logging
-    save_interval=200,  # More frequent saving
-
-    # Optimizer configuration - Optimized for rapid training
-    optimizer="adamw",
-    beta1=0.9,
-    beta2=0.95,
-    eps=1e-8,
-
-    # Scheduler configuration - Faster learning
-    scheduler="cosine",
-    min_lr=2e-6,  # Higher minimum LR
+@dataclass
+class SmolLM3ConfigH100Lightweight(SmolLM3Config):
+    """Configuration for SmolLM3 fine-tuning on OpenHermes-FR dataset - H100 Lightweight"""
+
+    # Model configuration - optimized for H100
+    model_name: str = "HuggingFaceTB/SmolLM3-3B"
+    max_seq_length: int = 8192  # Increased for better context understanding
+    use_flash_attention: bool = True
+    use_gradient_checkpointing: bool = True  # Enabled for memory efficiency
+
+    # Training configuration - H100 optimized for rapid training
+    batch_size: int = 16  # Larger batch size for H100
+    gradient_accumulation_steps: int = 4  # Reduced for faster updates
+    learning_rate: float = 8e-6  # Slightly higher for rapid convergence
+    weight_decay: float = 0.01
+    warmup_steps: int = 50  # Reduced warmup for rapid training
+    max_iters: Optional[int] = None  # Will be calculated based on epochs
+    eval_interval: int = 50  # More frequent evaluation
+    log_interval: int = 5  # More frequent logging
+    save_interval: int = 200  # More frequent saving
+
+    # Optimizer configuration - optimized for rapid training
+    optimizer: str = "adamw_torch"
+    beta1: float = 0.9
+    beta2: float = 0.95
+    eps: float = 1e-8
+
+    # Scheduler configuration - faster learning
+    scheduler: str = "cosine"
+    min_lr: float = 2e-6  # Higher minimum LR
 
     # Mixed precision - Full precision for H100
-    fp16=True,
-    bf16=False,
+    fp16: bool = True
+    bf16: bool = False
 
-    # Logging and saving - More frequent for rapid training
-    save_steps=200,
-    eval_steps=50,
-    logging_steps=5,
-    save_total_limit=2,  # Keep fewer checkpoints
+    # Logging and saving - more frequent for rapid training
+    save_steps: int = 200
+    eval_steps: int = 50
+    logging_steps: int = 5
+    save_total_limit: Optional[int] = 2  # Keep fewer checkpoints
 
     # Evaluation
-    eval_strategy="steps",
-    metric_for_best_model="eval_loss",
-    greater_is_better=False,
-    load_best_model_at_end=True,
-
-    # Data configuration - Hermes-FR with sampling
-    dataset_name="legmlai/openhermes-fr",
-    dataset_split="train",
-    input_field="prompt",
-    target_field="completion",
-    filter_bad_entries=False,
-    bad_entry_field="bad_entry",
-    sample_size=80000,  # 80K samples for lightweight training
-    sample_seed=42,  # For reproducibility
+    eval_strategy: str = "steps"
+    metric_for_best_model: str = "eval_loss"
+    greater_is_better: bool = False
+    load_best_model_at_end: bool = True
+
+    # OpenHermes-FR dataset configuration with sampling
+    dataset_name: str = "legmlai/openhermes-fr"
+    dataset_split: str = "train"
+    input_field: str = "prompt"
+    target_field: str = "completion"
+    filter_bad_entries: bool = False
+    bad_entry_field: str = "bad_entry"
+    sample_size: int = 80000  # 80K samples for lightweight training
+    sample_seed: int = 42  # For reproducibility
+
+    # Data configuration (not used for HF datasets but kept for compatibility)
+    data_dir: str = "my_dataset"
+    train_file: str = "train.json"
+    validation_file: Optional[str] = "validation.json"
+    test_file: Optional[str] = None
 
     # Chat template configuration
-    use_chat_template=True,
-    chat_template_kwargs={
-        "enable_thinking": False,
-        "add_generation_prompt": True,
-        "no_think_system_message": True
-    },
+    use_chat_template: bool = True
+    chat_template_kwargs: Optional[dict] = None  # Filled in __post_init__
 
     # Trackio monitoring configuration
-    enable_tracking=True,
-    trackio_url=None,  # Will be set by launch script
-    trackio_token=None,
-    log_artifacts=True,
-    log_metrics=True,
-    log_config=True,
-    experiment_name=None,  # Will be set by launch script
+    enable_tracking: bool = True
+    trackio_url: Optional[str] = None  # Will be set by launch script
+    trackio_token: Optional[str] = None
+    log_artifacts: bool = True
+    log_metrics: bool = True
+    log_config: bool = True
+    experiment_name: Optional[str] = None  # Will be set by launch script
 
     # HF Datasets configuration
-    dataset_repo=None,  # Will be set by launch script
+    hf_token: Optional[str] = None
+    dataset_repo: Optional[str] = None  # Will be set by launch script
 
     # H100-specific optimizations
-    dataloader_num_workers=4,  # Optimized for H100
-    dataloader_pin_memory=True,
-    gradient_clipping=1.0,  # Prevent gradient explosion
+    dataloader_num_workers: int = 4  # Optimized for H100
+    dataloader_pin_memory: bool = True
+    dataloader_prefetch_factor: int = 2
 
     # Memory optimizations for rapid training
-    max_grad_norm=1.0,
-    warmup_ratio=0.1,  # 10% warmup
-    lr_scheduler_type="cosine",
-
-    # Early stopping for rapid training
-    early_stopping_patience=3,
-    early_stopping_threshold=0.001,
-
-    # H100-specific training optimizations
-    remove_unused_columns=False,
-    group_by_length=True,  # Group similar length sequences
-    length_column_name="length",
-    ignore_data_skip=False,
-
-    # Reporting
-    report_to=["tensorboard"],
-    run_name="smollm3-h100-lightweight",
-
-    # Seed for reproducibility
-    seed=42,
-
-    # Data collator settings
-    data_collator_kwargs={
-        "pad_to_multiple_of": 8,  # Optimized for H100
-        "return_tensors": "pt"
-    }
-)
+    max_grad_norm: float = 1.0
+    group_by_length: bool = True  # Group similar length sequences
+
+    # Training duration calculations
+    # With 80K datapoints and an effective batch size of 64:
+    #   steps per epoch = 80,000 / 64 = 1,250
+    #   1 epoch  -> 1,250 steps
+    #   2 epochs -> 2,500 steps
+
+    def __post_init__(self):
+        if self.chat_template_kwargs is None:
+            self.chat_template_kwargs = {
+                "enable_thinking": False,
+                "add_generation_prompt": True,
+                "no_think_system_message": True
+            }
+
+        # Validate configuration
+        if self.fp16 and self.bf16:
+            raise ValueError("Cannot use both fp16 and bf16")
+
+        if self.max_seq_length > 131072:  # 128k limit
+            raise ValueError("max_seq_length cannot exceed 131072")
+
+        # Calculate training statistics
+        effective_batch_size = self.batch_size * self.gradient_accumulation_steps
+        steps_per_epoch = self.sample_size // effective_batch_size  # For the 80K sample
+        epochs_for_max_iters = self.max_iters / steps_per_epoch if self.max_iters else 1
+
+        print("=== H100 Lightweight Training Configuration ===")
+        print(f"Effective batch size: {effective_batch_size}")
+        print(f"Steps per epoch: ~{steps_per_epoch}")
+        print(f"Training for ~{epochs_for_max_iters:.1f} epochs")
+        print(f"Total training steps: {self.max_iters or 'auto'}")
+        print(f"Learning rate: {self.learning_rate}")
+        print(f"Mixed precision: {'fp16' if self.fp16 else 'bf16'}")
+        print(f"Max sequence length: {self.max_seq_length}")
+        print(f"Gradient checkpointing: {self.use_gradient_checkpointing}")
+        print(f"Dataset sample size: {self.sample_size}")
+        print("=" * 50)
+
+        # Set default experiment name if not provided
+        if self.experiment_name is None:
+            self.experiment_name = "smollm3_h100_lightweight"
+
+def get_config(config_path: str) -> SmolLM3ConfigH100Lightweight:
+    """Load configuration from file or return the default."""
+    if os.path.exists(config_path):
+        # Load from file if it exists
+        import importlib.util
+        spec = importlib.util.spec_from_file_location("config_module", config_path)
+        config_module = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(config_module)
+
+        if hasattr(config_module, 'config'):
+            return config_module.config
+        else:
+            # Try to find a config instance
+            for attr_name in dir(config_module):
+                attr = getattr(config_module, attr_name)
+                if isinstance(attr, SmolLM3ConfigH100Lightweight):
+                    return attr
+
+    # Return default configuration
+    return SmolLM3ConfigH100Lightweight()
+
+# Default configuration instance
+config = SmolLM3ConfigH100Lightweight()
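
A quick usage sketch for reviewers (not part of the commit; it assumes the repository layout above, and note that importing the module constructs the default instance and triggers the `__post_init__` printout):

# Hypothetical usage sketch, assuming the file layout in this commit.
from config.train_smollm3_h100_lightweight import get_config

cfg = get_config("config/train_smollm3_h100_lightweight.py")
effective = cfg.batch_size * cfg.gradient_accumulation_steps  # 16 * 4 = 64
print(cfg.sample_size // effective)  # 80,000 // 64 = 1,250 steps per epoch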
scripts/training/train.py CHANGED
@@ -53,6 +53,12 @@ def main():
         type=str,
         help="Trackio token for authentication"
     )
+    parser.add_argument(
+        "--dataset-dir",
+        type=str,
+        default="my_dataset",
+        help="Dataset directory path"
+    )
 
     args = parser.parse_args()
 
@@ -65,13 +71,13 @@ def main():
     # Import all available configurations
     from config.train_smollm3_openhermes_fr_a100_large import get_config as get_large_config
     from config.train_smollm3_openhermes_fr_a100_multiple_passes import get_config as get_multiple_passes_config
-    from config.train_smollm3_h100_lightweight import config as h100_lightweight_config
+    from config.train_smollm3_h100_lightweight import get_config as get_h100_lightweight_config
 
     # Map config files to their respective functions
    config_map = {
         "config/train_smollm3_openhermes_fr_a100_large.py": get_large_config,
         "config/train_smollm3_openhermes_fr_a100_multiple_passes.py": get_multiple_passes_config,
-        "config/train_smollm3_h100_lightweight.py": lambda x: h100_lightweight_config,
+        "config/train_smollm3_h100_lightweight.py": get_h100_lightweight_config,
     }
 
     if args.config in config_map:
@@ -116,7 +122,15 @@ def main():
     print(f"Max iterations: {config.max_iters}")
     print(f"Max sequence length: {config.max_seq_length}")
     print(f"Mixed precision: {'bf16' if config.bf16 else 'fp16'}")
-    print(f"Dataset: {config.dataset_name}")
+    if hasattr(config, 'dataset_name') and config.dataset_name:
+        print(f"Dataset: {config.dataset_name}")
+        if hasattr(config, 'sample_size') and config.sample_size:
+            print(f"Sample size: {config.sample_size}")
+    else:
+        print(f"Dataset directory: {config.data_dir}")
+        print(f"Training file: {config.train_file}")
+        if config.validation_file:
+            print(f"Validation file: {config.validation_file}")
     if config.trackio_url:
         print(f"Trackio URL: {config.trackio_url}")
     if config.trackio_token:
@@ -151,6 +165,9 @@ def main():
     if args.experiment_name:
         train_args.extend(["--experiment_name", args.experiment_name])
 
+    # Add dataset directory argument
+    train_args.extend(["--dataset_dir", args.dataset_dir])
+
     # Override sys.argv for the training script
     original_argv = sys.argv
     sys.argv = ["train.py"] + train_args
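
The config bug appears to be in this dispatch: the old mapping imported the module-level `config` instance eagerly (executing the config module at import time) and discarded the path argument via `lambda x: ...`. After the change, every entry is a uniform path-to-config callable. A minimal sketch of the pattern, with names taken from the diff:

# Minimal sketch of the uniform dispatch pattern used above.
from config.train_smollm3_h100_lightweight import get_config as get_h100_lightweight_config

config_map = {
    "config/train_smollm3_h100_lightweight.py": get_h100_lightweight_config,
}
config_path = "config/train_smollm3_h100_lightweight.py"
config = config_map[config_path](config_path)  # every value accepts the path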
src/train.py CHANGED
@@ -174,13 +174,17 @@ def main():
     )
 
     # Determine dataset path
+    # Check if using Hugging Face dataset or local dataset
     if hasattr(config, 'dataset_name') and config.dataset_name:
         # Use Hugging Face dataset
         dataset_path = config.dataset_name
         logger.info(f"Using Hugging Face dataset: {dataset_path}")
     else:
-        # Use local dataset
-        dataset_path = os.path.join('/input', args.dataset_dir)
+        # Use local dataset from config or command line argument
+        if args.dataset_dir:
+            dataset_path = os.path.join('/input', args.dataset_dir)
+        else:
+            dataset_path = os.path.join('/input', config.data_dir)
         logger.info(f"Using local dataset: {dataset_path}")
 
     # Load dataset with filtering options and sampling
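
Resolution order after this hunk: a Hugging Face `dataset_name` on the config wins; otherwise the `--dataset_dir` argument is used, falling back to the config's `data_dir`, with local paths rooted under `/input`. A self-contained sketch of that precedence (a hypothetical helper, not in the commit):

import os

def resolve_dataset_path(config, dataset_dir_arg=None):
    # Hypothetical helper mirroring the precedence in src/train.py above.
    if getattr(config, 'dataset_name', None):
        return config.dataset_name  # Hugging Face dataset id
    return os.path.join('/input', dataset_dir_arg or config.data_dir)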
test_config.py ADDED
@@ -0,0 +1,53 @@
+#!/usr/bin/env python3
+"""
+Test script to verify H100 lightweight configuration loads correctly
+"""
+
+import sys
+import os
+
+# Add project root to path
+project_root = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, project_root)
+
+def test_h100_lightweight_config():
+    """Test the H100 lightweight configuration"""
+    try:
+        from config.train_smollm3_h100_lightweight import config
+
+        print("✅ H100 Lightweight configuration loaded successfully!")
+        print(f"Model: {config.model_name}")
+        print(f"Dataset: {config.dataset_name}")
+        print(f"Sample size: {config.sample_size}")
+        print(f"Batch size: {config.batch_size}")
+        print(f"Learning rate: {config.learning_rate}")
+        print(f"Max sequence length: {config.max_seq_length}")
+
+        return True
+    except Exception as e:
+        print(f"❌ Error loading H100 lightweight configuration: {e}")
+        return False
+
+def test_training_script_import():
+    """Test that the training script can import the configuration"""
+    try:
+        from scripts.training.train import main
+        print("✅ Training script imports successfully!")
+        return True
+    except Exception as e:
+        print(f"❌ Error importing training script: {e}")
+        return False
+
+if __name__ == "__main__":
+    print("Testing H100 Lightweight Configuration...")
+    print("=" * 50)
+
+    success = True
+    success &= test_h100_lightweight_config()
+    success &= test_training_script_import()
+
+    if success:
+        print("\n🎉 All tests passed! Configuration is ready for training.")
+    else:
+        print("\n❌ Some tests failed. Please check the configuration.")
+        sys.exit(1)
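
This smoke test is meant to be run from the repository root, e.g. `python test_config.py`; with the defaults above it should print the 80K-sample, batch-16 settings and exit with status 0, or exit 1 if either the config load or the training-script import fails.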