Tonic committed
Commit d60ab6c · verified · 1 parent: 58a74d2

Solves OOM error with a more reasonable configuration

config/train_smollm3_openhermes_fr_a100_balanced.py ADDED
@@ -0,0 +1,164 @@
+"""
+SmolLM3 Training Configuration for OpenHermes-FR Dataset - A100 Balanced
+Optimized for good GPU utilization without running out of memory
+"""
+
+import os
+from dataclasses import dataclass
+from typing import Optional
+from config.train_smollm3 import SmolLM3Config
+
+@dataclass
+class SmolLM3ConfigOpenHermesFRBalanced(SmolLM3Config):
+    """Configuration for SmolLM3 fine-tuning with balanced A100 performance"""
+
+    # Model configuration - balanced for A100
+    model_name: str = "HuggingFaceTB/SmolLM3-3B"
+    max_seq_length: int = 12288  # Increased but not too much
+    use_flash_attention: bool = True
+    use_gradient_checkpointing: bool = False  # Disabled for A100 efficiency
+
+    # Training configuration - balanced GPU utilization
+    batch_size: int = 8  # Moderate increase
+    gradient_accumulation_steps: int = 16  # Effective batch size = 8 * 16 = 128
+    learning_rate: float = 3.5e-6  # Slightly higher for larger effective batch
+    weight_decay: float = 0.01
+    warmup_steps: int = 1200  # More warmup for larger batch
+    max_iters: int = 18000  # More iterations for faster convergence
+    eval_interval: int = 1000  # Less frequent evaluation
+    log_interval: int = 25  # Less frequent logging
+    save_interval: int = 2000  # Less frequent saving
+
+    # Optimizer configuration - optimized for large batches
+    optimizer: str = "adamw_torch"
+    beta1: float = 0.9
+    beta2: float = 0.999  # Higher beta2 for stability
+    eps: float = 1e-8
+
+    # Scheduler configuration - faster training
+    scheduler: str = "cosine"
+    min_lr: float = 3.5e-7  # Lower min LR
+
+    # Mixed precision - A100 optimized
+    fp16: bool = False  # Use bf16 for A100
+    bf16: bool = True  # Better for A100
+
+    # DDP configuration
+    ddp_backend: str = "nccl"
+    ddp_find_unused_parameters: bool = False
+
+    # Logging and saving - optimized for fast training
+    save_steps: int = 2000
+    eval_steps: int = 1000
+    logging_steps: int = 25
+    save_total_limit: Optional[int] = 5  # Keep fewer checkpoints
+
+    # Evaluation
+    eval_strategy: str = "steps"
+    metric_for_best_model: str = "eval_loss"
+    greater_is_better: bool = False
+    load_best_model_at_end: bool = True
+
+    # OpenHermes-FR Dataset configuration
+    dataset_name: str = "legmlai/openhermes-fr"
+    dataset_split: str = "train"
+    input_field: str = "prompt"
+    target_field: str = "accepted_completion"
+    filter_bad_entries: bool = True
+    bad_entry_field: str = "bad_entry"
+
+    # Data configuration (not used for HF datasets but kept for compatibility)
+    data_dir: str = None
+    train_file: str = None
+    validation_file: Optional[str] = None
+    test_file: Optional[str] = None
+
+    # Chat template configuration
+    use_chat_template: bool = True
+    chat_template_kwargs: dict = None
+
+    # Trackio monitoring configuration
+    enable_tracking: bool = True
+    trackio_url: Optional[str] = None
+    trackio_token: Optional[str] = None
+    log_artifacts: bool = True
+    log_metrics: bool = True
+    log_config: bool = True
+    experiment_name: Optional[str] = None
+
+    # Additional A100 optimizations for balanced performance
+    dataloader_num_workers: int = 10  # More workers for faster data loading
+    dataloader_pin_memory: bool = True
+    dataloader_prefetch_factor: int = 3  # Increased prefetch
+
+    # Memory optimizations
+    max_grad_norm: float = 1.0  # Gradient clipping
+    group_by_length: bool = True  # Group similar length sequences
+
+    # Training duration calculations
+    # With 800k datapoints and effective batch size of 128:
+    # Steps per epoch = 800,000 / 128 = 6,250 steps
+    # For 3 passes: 6,250 * 3 = 18,750 steps
+    # Current max_iters = 18,000 (about 2.9 passes)
+
+    def __post_init__(self):
+        if self.chat_template_kwargs is None:
+            self.chat_template_kwargs = {
+                "enable_thinking": False,
+                "add_generation_prompt": True
+            }
+
+        # Validate configuration
+        if self.fp16 and self.bf16:
+            raise ValueError("Cannot use both fp16 and bf16")
+
+        if self.max_seq_length > 131072:  # 128k limit
+            raise ValueError("max_seq_length cannot exceed 131072")
+
+        # Calculate training statistics
+        effective_batch_size = self.batch_size * self.gradient_accumulation_steps
+        steps_per_epoch = 800000 // effective_batch_size  # Approximate for 800k dataset
+        epochs_for_max_iters = self.max_iters / steps_per_epoch
+
+        print("=== A100 Balanced Configuration ===")
+        print(f"Effective batch size: {effective_batch_size}")
+        print(f"Steps per epoch: ~{steps_per_epoch}")
+        print(f"Training for ~{epochs_for_max_iters:.1f} epochs")
+        print(f"Total training steps: {self.max_iters}")
+        print(f"Learning rate: {self.learning_rate}")
+        print(f"Mixed precision: {'bf16' if self.bf16 else 'fp16'}")
+        print(f"Max sequence length: {self.max_seq_length}")
+        print(f"Gradient checkpointing: {self.use_gradient_checkpointing}")
+        print(f"Batch size: {self.batch_size}")
+        print(f"Gradient accumulation: {self.gradient_accumulation_steps}")
+        print(f"Data loader workers: {self.dataloader_num_workers}")
+        print(f"Prefetch factor: {self.dataloader_prefetch_factor}")
+        print("=" * 50)
+
+        # Set default experiment name if not provided
+        if self.experiment_name is None:
+            self.experiment_name = "smollm3_openhermes_fr_balanced"
+
+def get_config(config_path: str) -> SmolLM3ConfigOpenHermesFRBalanced:
+    """Load configuration from file or return default"""
+    if os.path.exists(config_path):
+        # Load from file if it exists
+        import importlib.util
+        spec = importlib.util.spec_from_file_location("config_module", config_path)
+        config_module = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(config_module)
+
+        if hasattr(config_module, 'config'):
+            return config_module.config
+        else:
+            # Try to find a config class
+            for attr_name in dir(config_module):
+                attr = getattr(config_module, attr_name)
+                if isinstance(attr, SmolLM3ConfigOpenHermesFRBalanced):
+                    return attr
+
+    # Return default configuration
+    return SmolLM3ConfigOpenHermesFRBalanced()
+
+# Default configuration instance
+config = SmolLM3ConfigOpenHermesFRBalanced()
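
For reference, a minimal usage sketch of the new module, assuming it is run from the repository root so that config/train_smollm3.py is importable; the training launcher itself is not shown in this commit, so the snippet only loads and inspects the configuration:

# Minimal sketch (assumption: repo root on PYTHONPATH; no training is started here).
from config.train_smollm3_openhermes_fr_a100_balanced import get_config

cfg = get_config("config/train_smollm3_openhermes_fr_a100_balanced.py")  # prints the summary from __post_init__
effective_batch = cfg.batch_size * cfg.gradient_accumulation_steps       # 8 * 16 = 128
print(cfg.model_name, cfg.max_seq_length, effective_batch)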
run_a100_large_experiment.py CHANGED
@@ -9,6 +9,9 @@ import os
 import sys
 from pathlib import Path
 
+# Set CUDA memory optimization
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
 def main():
     parser = argparse.ArgumentParser(description="Run A100 large-scale experiments")
     parser.add_argument(
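
The allocator setting only takes effect if it is in the environment before PyTorch initializes its CUDA caching allocator (i.e. before the first CUDA allocation), which is why the commit sets it at module import time rather than inside main(). A short verification sketch, assuming PyTorch 2.1+ with CUDA available; this check is not part of the commit:

# Hedged verification sketch; nothing below comes from the repository itself.
import os
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import torch  # imported after the env var is set

print("Allocator config:", os.environ.get("PYTORCH_CUDA_ALLOC_CONF"))
if torch.cuda.is_available():
    _ = torch.zeros(1, device="cuda")  # first allocation initializes the caching allocator
    print(torch.cuda.memory_summary(abbreviated=True))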