attempt to fix bfloat16 issue
- config/train_smollm3_h100_lightweight.py +4 -3
- quick_test_training.py +60 -0
- src/model.py +29 -6
- test_mixed_precision.py +63 -0
- test_training_fix.py +62 -0
config/train_smollm3_h100_lightweight.py
CHANGED
@@ -39,9 +39,10 @@ class SmolLM3ConfigH100Lightweight(SmolLM3Config):
     scheduler: str = "cosine"
     min_lr: float = 2e-6  # Higher minimum LR
 
-    # Mixed precision -
-
-
+    # Mixed precision - using bf16 by default (model loading falls back to fp16 if bf16 fails)
+    # Note: bf16 can cause issues on some GPU setups, fp16 is more universally supported
+    fp16: bool = False
+    bf16: bool = True
 
     # Logging and saving - more frequent for rapid training
     save_steps: int = 200
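
For orientation, here is a minimal sketch of how a fp16/bf16 flag pair like the one above typically maps onto Hugging Face TrainingArguments. The helper name and the getattr defaults are illustrative assumptions, not the repository's actual get_training_arguments implementation:

from transformers import TrainingArguments

def build_training_arguments(config, output_dir: str) -> TrainingArguments:
    # Only one of fp16/bf16 should be enabled at a time; with the config above
    # (fp16=False, bf16=True) the Trainer runs in bfloat16 mixed precision.
    return TrainingArguments(
        output_dir=output_dir,
        fp16=getattr(config, "fp16", False),
        bf16=getattr(config, "bf16", False),
        save_steps=getattr(config, "save_steps", 200),
    )
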
quick_test_training.py
ADDED
@@ -0,0 +1,60 @@
+#!/usr/bin/env python3
+"""
+Quick test for the training fix
+"""
+
+import os
+import sys
+
+# Add project root to path
+project_root = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, project_root)
+
+def main():
+    print("🔧 Testing H100 Lightweight Training Fix")
+    print("=" * 50)
+
+    # Set environment variables to fix mixed precision issues
+    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+    os.environ["TORCH_USE_CUDA_DSA"] = "1"
+
+    print("✅ Environment variables set")
+
+    # Test configuration
+    try:
+        from config.train_smollm3_h100_lightweight import SmolLM3ConfigH100Lightweight
+        config = SmolLM3ConfigH100Lightweight()
+        print(f"✅ Configuration loaded: fp16={config.fp16}, bf16={config.bf16}")
+
+        # Test model loading (without actually loading the full model)
+        from src.model import SmolLM3Model
+
+        # Create model instance
+        model = SmolLM3Model(
+            model_name="HuggingFaceTB/SmolLM3-3B",
+            max_seq_length=4096,
+            config=config
+        )
+
+        print(f"✅ Model dtype: {model.torch_dtype}")
+        print(f"✅ Model device map: {model.device_map}")
+
+        # Test training arguments
+        training_args = model.get_training_arguments("/tmp/test")
+        print(f"✅ Training args: fp16={training_args.fp16}, bf16={training_args.bf16}")
+
+        print("\n🎉 All tests passed!")
+        print("You can now run the training with:")
+        print("  ./launch.sh")
+
+    except Exception as e:
+        print(f"❌ Error: {e}")
+        import traceback
+        traceback.print_exc()
+        return 1
+
+    return 0
+
+if __name__ == "__main__":
+    exit(main())
src/model.py
CHANGED
@@ -36,7 +36,16 @@ class SmolLM3Model:
         # Set device and dtype
         if torch_dtype is None:
             if torch.cuda.is_available():
-
+                # Check if config specifies mixed precision
+                if config and hasattr(config, 'fp16') and config.fp16:
+                    # Use fp16 if explicitly configured
+                    self.torch_dtype = torch.float16
+                elif config and hasattr(config, 'bf16') and config.bf16:
+                    # Use bf16 if explicitly configured
+                    self.torch_dtype = torch.bfloat16
+                else:
+                    # Default to bf16 for better compatibility
+                    self.torch_dtype = torch.bfloat16
             else:
                 self.torch_dtype = torch.float32
         else:
@@ -110,11 +119,25 @@ class SmolLM3Model:
             # If flash attention is not supported, skip it
             pass
 
-
-
-
-
-
+        # Try to load the model, fall back to fp16 if bf16 fails
+        try:
+            self.model = AutoModelForCausalLM.from_pretrained(
+                self.model_name,
+                config=model_config,
+                **model_kwargs
+            )
+        except RuntimeError as e:
+            if "bfloat16" in str(e) or "BFloat16" in str(e):
+                logger.warning("BFloat16 not supported, falling back to Float16")
+                model_kwargs["torch_dtype"] = torch.float16
+                self.torch_dtype = torch.float16
+                self.model = AutoModelForCausalLM.from_pretrained(
+                    self.model_name,
+                    config=model_config,
+                    **model_kwargs
+                )
+            else:
+                raise
 
         # Enable gradient checkpointing if specified
         if self.config and self.config.use_gradient_checkpointing:
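
As a complement to the runtime fallback above, bf16 support can also be probed before loading. A minimal sketch under the assumption that the decision is made up front; torch.cuda.is_bf16_supported() is the standard PyTorch check, but this helper is not part of the repository:

import torch

def pick_mixed_precision_dtype(prefer_bf16: bool = True) -> torch.dtype:
    # CPU-only runs stay in float32; on CUDA, use bfloat16 only when the
    # device actually supports it, otherwise drop to float16.
    if not torch.cuda.is_available():
        return torch.float32
    if prefer_bf16 and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16

Checking up front avoids paying for a failed from_pretrained call, while the try/except in the diff also covers cases where the capability check passes but a kernel still rejects bfloat16.
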
test_mixed_precision.py
ADDED
@@ -0,0 +1,63 @@
+#!/usr/bin/env python3
+"""
+Test script to verify mixed precision configuration
+"""
+
+import torch
+import sys
+import os
+
+# Add project root to path
+project_root = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, project_root)
+
+def test_mixed_precision():
+    """Test mixed precision configuration"""
+    print("Testing mixed precision configuration...")
+
+    # Test 1: Check CUDA availability
+    print(f"CUDA available: {torch.cuda.is_available()}")
+    if torch.cuda.is_available():
+        print(f"CUDA device count: {torch.cuda.device_count()}")
+        print(f"CUDA device capability: {torch.cuda.get_device_capability()}")
+        print(f"Current device: {torch.cuda.current_device()}")
+
+    # Test 2: Test model loading with different dtypes
+    try:
+        from src.model import SmolLM3Model
+        from config.train_smollm3_h100_lightweight import SmolLM3ConfigH100Lightweight
+
+        config = SmolLM3ConfigH100Lightweight()
+        print(f"Config fp16: {config.fp16}")
+        print(f"Config bf16: {config.bf16}")
+
+        # Test model loading
+        model = SmolLM3Model(
+            model_name="HuggingFaceTB/SmolLM3-3B",
+            max_seq_length=4096,
+            config=config
+        )
+
+        print(f"Model dtype: {model.torch_dtype}")
+        print(f"Model device map: {model.device_map}")
+        print("✅ Model loading successful!")
+
+        # Test training arguments
+        training_args = model.get_training_arguments("/tmp/test")
+        print(f"Training args fp16: {training_args.fp16}")
+        print(f"Training args bf16: {training_args.bf16}")
+        print("✅ Training arguments created successfully!")
+
+    except Exception as e:
+        print(f"❌ Error: {e}")
+        return False
+
+    return True
+
+if __name__ == "__main__":
+    success = test_mixed_precision()
+    if success:
+        print("\n🎉 Mixed precision test passed!")
+    else:
+        print("\n❌ Mixed precision test failed!")
+        sys.exit(1)
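
The committed test only checks that model and TrainingArguments construction succeed; a quick way to confirm the GPU actually executes bfloat16 kernels is to run a tiny bf16 matmul. This helper is a sketch and not part of the commit:

import torch

def bf16_smoke_test() -> bool:
    # Multiplies two small bfloat16 matrices on the GPU; a RuntimeError here
    # suggests the fp16 fallback path in src/model.py would be taken.
    if not torch.cuda.is_available():
        return False
    try:
        a = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16)
        b = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16)
        _ = a @ b
        torch.cuda.synchronize()
        return True
    except RuntimeError:
        return False
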
test_training_fix.py
ADDED
@@ -0,0 +1,62 @@
+#!/usr/bin/env python3
+"""
+Quick test to verify the training configuration fix
+"""
+
+import os
+import sys
+
+# Add project root to path
+project_root = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, project_root)
+
+def test_configuration():
+    """Test the H100 lightweight configuration"""
+    print("Testing H100 Lightweight Configuration...")
+
+    try:
+        from config.train_smollm3_h100_lightweight import SmolLM3ConfigH100Lightweight
+
+        config = SmolLM3ConfigH100Lightweight()
+
+        print("✅ Configuration loaded successfully")
+        print(f"   Model: {config.model_name}")
+        print(f"   Batch size: {config.batch_size}")
+        print(f"   Learning rate: {config.learning_rate}")
+        print(f"   FP16: {config.fp16}")
+        print(f"   BF16: {config.bf16}")
+        print(f"   Mixed precision: {'fp16' if config.fp16 else 'bf16'}")
+        print(f"   Sample size: {config.sample_size}")
+
+        # Test training arguments creation
+        from src.model import SmolLM3Model
+
+        # Create a minimal model instance for testing
+        model = SmolLM3Model(
+            model_name="HuggingFaceTB/SmolLM3-3B",
+            max_seq_length=4096,
+            config=config
+        )
+
+        # Test training arguments
+        training_args = model.get_training_arguments("/tmp/test")
+        print(f"✅ Training arguments created successfully")
+        print(f"   Training args FP16: {training_args.fp16}")
+        print(f"   Training args BF16: {training_args.bf16}")
+
+        return True
+
+    except Exception as e:
+        print(f"❌ Error: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+if __name__ == "__main__":
+    success = test_configuration()
+    if success:
+        print("\n🎉 Configuration test passed!")
+        print("You can now run the training with: ./launch.sh")
+    else:
+        print("\n❌ Configuration test failed!")
+        sys.exit(1)