Tonic committed (verified)
Commit 58a74d2 · 1 Parent(s): a2e0482

adds max perf script

config/__init__.py CHANGED
@@ -6,6 +6,7 @@ from .train_smollm3 import SmolLM3Config, get_config as get_base_config
 from .train_smollm3_openhermes_fr import SmolLM3ConfigOpenHermesFR, get_config as get_openhermes_fr_config
 from .train_smollm3_openhermes_fr_a100_large import SmolLM3ConfigOpenHermesFRA100Large, get_config as get_a100_large_config
 from .train_smollm3_openhermes_fr_a100_multiple_passes import SmolLM3ConfigOpenHermesFRMultiplePasses, get_config as get_multiple_passes_config
+from .train_smollm3_openhermes_fr_a100_max_performance import SmolLM3ConfigOpenHermesFRMaxPerformance, get_config as get_max_performance_config
 
 # Generic get_config function that can handle different config types
 def get_config(config_path: str):
@@ -20,6 +21,8 @@ def get_config(config_path: str):
         return get_a100_large_config(config_path)
     elif "a100_multiple_passes" in config_path:
         return get_multiple_passes_config(config_path)
+    elif "a100_max_performance" in config_path:
+        return get_max_performance_config(config_path)
     elif "openhermes_fr" in config_path:
         return get_openhermes_fr_config(config_path)
     else:
@@ -30,9 +33,11 @@ __all__ = [
     'SmolLM3ConfigOpenHermesFR',
     'SmolLM3ConfigOpenHermesFRA100Large',
     'SmolLM3ConfigOpenHermesFRMultiplePasses',
+    'SmolLM3ConfigOpenHermesFRMaxPerformance',
     'get_config',
     'get_base_config',
     'get_openhermes_fr_config',
     'get_a100_large_config',
     'get_multiple_passes_config',
+    'get_max_performance_config',
 ]
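
Note that the new elif branch matches on the substring "a100_max_performance" and sits before the broader "openhermes_fr" check, which would otherwise capture the path first. A minimal sketch of how the updated dispatcher would be exercised, assuming the repository root is on sys.path:

    # Hypothetical usage of the updated dispatcher from config/__init__.py.
    from config import get_config

    # The path contains "a100_max_performance", so the new branch fires
    # before the generic "openhermes_fr" fallback.
    cfg = get_config("config/train_smollm3_openhermes_fr_a100_max_performance.py")
    print(type(cfg).__name__)  # SmolLM3ConfigOpenHermesFRMaxPerformance
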
config/train_smollm3_openhermes_fr_a100_max_performance.py ADDED
@@ -0,0 +1,165 @@
+"""
+SmolLM3 Training Configuration for OpenHermes-FR Dataset - A100 Max Performance
+Optimized for maximum GPU utilization and fastest training on A100
+"""
+
+import os
+from dataclasses import dataclass
+from typing import Optional
+from config.train_smollm3 import SmolLM3Config
+
+@dataclass
+class SmolLM3ConfigOpenHermesFRMaxPerformance(SmolLM3Config):
+    """Configuration for SmolLM3 fine-tuning with maximum A100 performance"""
+
+    # Model configuration - optimized for A100
+    model_name: str = "HuggingFaceTB/SmolLM3-3B"
+    max_seq_length: int = 16384  # Increased for better GPU utilization
+    use_flash_attention: bool = True
+    use_gradient_checkpointing: bool = False  # Disabled for A100 efficiency
+
+    # Training configuration - maximum GPU utilization
+    batch_size: int = 12  # Increased batch size for A100
+    gradient_accumulation_steps: int = 12  # Effective batch size = 12 * 12 = 144
+    learning_rate: float = 4e-6  # Slightly higher for the larger effective batch
+    weight_decay: float = 0.01
+    warmup_steps: int = 1500  # More warmup for the larger batch
+    max_iters: int = 20000  # About 3.6 passes over the dataset (see calculation below)
+    eval_interval: int = 1000  # Less frequent evaluation
+    log_interval: int = 25  # Less frequent logging
+    save_interval: int = 2000  # Less frequent saving
+
+    # Optimizer configuration - optimized for large batches
+    optimizer: str = "adamw_torch"
+    beta1: float = 0.9
+    beta2: float = 0.999  # Higher beta2 for stability
+    eps: float = 1e-8
+
+    # Scheduler configuration - faster training
+    scheduler: str = "cosine"
+    min_lr: float = 4e-7  # Lower min LR
+
+    # Mixed precision - A100 optimized
+    fp16: bool = False  # Use bf16 for A100
+    bf16: bool = True  # Better for A100
+
+    # DDP configuration
+    ddp_backend: str = "nccl"
+    ddp_find_unused_parameters: bool = False
+
+    # Logging and saving - optimized for fast training
+    save_steps: int = 2000
+    eval_steps: int = 1000
+    logging_steps: int = 25
+    save_total_limit: Optional[int] = 5  # Keep fewer checkpoints
+
+    # Evaluation
+    eval_strategy: str = "steps"
+    metric_for_best_model: str = "eval_loss"
+    greater_is_better: bool = False
+    load_best_model_at_end: bool = True
+
+    # OpenHermes-FR dataset configuration
+    dataset_name: str = "legmlai/openhermes-fr"
+    dataset_split: str = "train"
+    input_field: str = "prompt"
+    target_field: str = "accepted_completion"
+    filter_bad_entries: bool = True
+    bad_entry_field: str = "bad_entry"
+
+    # Data configuration (not used for HF datasets but kept for compatibility)
+    data_dir: Optional[str] = None
+    train_file: Optional[str] = None
+    validation_file: Optional[str] = None
+    test_file: Optional[str] = None
+
+    # Chat template configuration
+    use_chat_template: bool = True
+    chat_template_kwargs: Optional[dict] = None
+
+    # Trackio monitoring configuration
+    enable_tracking: bool = True
+    trackio_url: Optional[str] = None
+    trackio_token: Optional[str] = None
+    log_artifacts: bool = True
+    log_metrics: bool = True
+    log_config: bool = True
+    experiment_name: Optional[str] = None
+
+    # Additional A100 optimizations for maximum performance
+    dataloader_num_workers: int = 12  # More workers for faster data loading
+    dataloader_pin_memory: bool = True
+    dataloader_prefetch_factor: int = 4  # Increased prefetch
+
+    # Memory optimizations
+    max_grad_norm: float = 1.0  # Gradient clipping
+    group_by_length: bool = True  # Group similar-length sequences
+
+    # Training duration calculations
+    # With 800k datapoints and an effective batch size of 144:
+    # Steps per epoch = 800,000 / 144 ≈ 5,556
+    # For 3 passes: 5,556 * 3 ≈ 16,667 steps
+    # For 4 passes: 5,556 * 4 ≈ 22,222 steps
+    # Current max_iters = 20,000 (about 3.6 passes)
+
+    def __post_init__(self):
+        if self.chat_template_kwargs is None:
+            self.chat_template_kwargs = {
+                "enable_thinking": False,
+                "add_generation_prompt": True
+            }
+
+        # Validate configuration
+        if self.fp16 and self.bf16:
+            raise ValueError("Cannot use both fp16 and bf16")
+
+        if self.max_seq_length > 131072:  # 128k limit
+            raise ValueError("max_seq_length cannot exceed 131072")
+
+        # Calculate training statistics
+        effective_batch_size = self.batch_size * self.gradient_accumulation_steps
+        steps_per_epoch = 800000 // effective_batch_size  # Approximate for the 800k dataset
+        epochs_for_max_iters = self.max_iters / steps_per_epoch
+
+        print("=== A100 Max Performance Configuration ===")
+        print(f"Effective batch size: {effective_batch_size}")
+        print(f"Steps per epoch: ~{steps_per_epoch}")
+        print(f"Training for ~{epochs_for_max_iters:.1f} epochs")
+        print(f"Total training steps: {self.max_iters}")
+        print(f"Learning rate: {self.learning_rate}")
+        print(f"Mixed precision: {'bf16' if self.bf16 else 'fp16'}")
+        print(f"Max sequence length: {self.max_seq_length}")
+        print(f"Gradient checkpointing: {self.use_gradient_checkpointing}")
+        print(f"Batch size: {self.batch_size}")
+        print(f"Gradient accumulation: {self.gradient_accumulation_steps}")
+        print(f"Data loader workers: {self.dataloader_num_workers}")
+        print(f"Prefetch factor: {self.dataloader_prefetch_factor}")
+        print("=" * 50)
+
+        # Set a default experiment name if not provided
+        if self.experiment_name is None:
+            self.experiment_name = "smollm3_openhermes_fr_max_performance"
+
+def get_config(config_path: str) -> SmolLM3ConfigOpenHermesFRMaxPerformance:
+    """Load configuration from file or return the default"""
+    if os.path.exists(config_path):
+        # Load from file if it exists
+        import importlib.util
+        spec = importlib.util.spec_from_file_location("config_module", config_path)
+        config_module = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(config_module)
+
+        if hasattr(config_module, 'config'):
+            return config_module.config
+        else:
+            # Try to find a config instance
+            for attr_name in dir(config_module):
+                attr = getattr(config_module, attr_name)
+                if isinstance(attr, SmolLM3ConfigOpenHermesFRMaxPerformance):
+                    return attr
+
+    # Return the default configuration
+    return SmolLM3ConfigOpenHermesFRMaxPerformance()
+
+# Default configuration instance
+config = SmolLM3ConfigOpenHermesFRMaxPerformance()
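
For reference, the training-duration arithmetic in the config's comments checks out; a minimal sketch, assuming the ~800,000-example dataset size the config itself assumes:

    # Hypothetical sanity check of the step math above.
    batch_size = 12
    grad_accum = 12
    dataset_size = 800_000  # approximate OpenHermes-FR size assumed by the config

    effective_batch = batch_size * grad_accum          # 144
    steps_per_epoch = dataset_size // effective_batch  # 5555 (comments round to 5,556)
    passes_at_max_iters = 20_000 / steps_per_epoch     # ~3.6 passes at max_iters=20000
    print(effective_batch, steps_per_epoch, round(passes_at_max_iters, 1))  # 144 5555 3.6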