Tonic committed on
Commit
6c63876
·
verified ·
1 Parent(s): c33a1e3

attempt to fix bfloat16 issue

Browse files
config/train_smollm3_h100_lightweight.py CHANGED
@@ -39,9 +39,10 @@ class SmolLM3ConfigH100Lightweight(SmolLM3Config):
39
  scheduler: str = "cosine"
40
  min_lr: float = 2e-6 # Higher minimum LR
41
 
42
- # Mixed precision - Full precision for H100
43
- fp16: bool = True
44
- bf16: bool = False
 
45
 
46
  # Logging and saving - more frequent for rapid training
47
  save_steps: int = 200
 
39
  scheduler: str = "cosine"
40
  min_lr: float = 2e-6 # Higher minimum LR
41
 
42
+ # Mixed precision - using bf16 (fp16 disabled)
43
+ # Note: bf16 can cause issues on some GPU setups; fp16 is more universally supported
44
+ fp16: bool = False
45
+ bf16: bool = True
46
 
47
  # Logging and saving - more frequent for rapid training
48
  save_steps: int = 200
quick_test_training.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Quick test for the training fix
4
+ """
5
+
6
+ import os
7
+ import sys
8
+
9
+ # Add project root to path
10
+ project_root = os.path.dirname(os.path.abspath(__file__))
11
+ sys.path.insert(0, project_root)
12
+
13
+ def main():
14
+ print("🔧 Testing H100 Lightweight Training Fix")
15
+ print("=" * 50)
16
+
17
+ # Set environment variables to fix mixed precision issues
18
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
19
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
20
+ os.environ["TORCH_USE_CUDA_DSA"] = "1"
21
+
22
+ print("✅ Environment variables set")
23
+
24
+ # Test configuration
25
+ try:
26
+ from config.train_smollm3_h100_lightweight import SmolLM3ConfigH100Lightweight
27
+ config = SmolLM3ConfigH100Lightweight()
28
+ print(f"✅ Configuration loaded: fp16={config.fp16}, bf16={config.bf16}")
29
+
30
+ # Test model loading (without actually loading the full model)
31
+ from src.model import SmolLM3Model
32
+
33
+ # Create model instance
34
+ model = SmolLM3Model(
35
+ model_name="HuggingFaceTB/SmolLM3-3B",
36
+ max_seq_length=4096,
37
+ config=config
38
+ )
39
+
40
+ print(f"✅ Model dtype: {model.torch_dtype}")
41
+ print(f"✅ Model device map: {model.device_map}")
42
+
43
+ # Test training arguments
44
+ training_args = model.get_training_arguments("/tmp/test")
45
+ print(f"✅ Training args: fp16={training_args.fp16}, bf16={training_args.bf16}")
46
+
47
+ print("\n🎉 All tests passed!")
48
+ print("You can now run the training with:")
49
+ print(" ./launch.sh")
50
+
51
+ except Exception as e:
52
+ print(f"❌ Error: {e}")
53
+ import traceback
54
+ traceback.print_exc()
55
+ return 1
56
+
57
+ return 0
58
+
59
+ if __name__ == "__main__":
60
+ exit(main())
src/model.py CHANGED
@@ -36,7 +36,16 @@ class SmolLM3Model:
36
  # Set device and dtype
37
  if torch_dtype is None:
38
  if torch.cuda.is_available():
39
- self.torch_dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16
 
 
 
 
 
 
 
 
 
40
  else:
41
  self.torch_dtype = torch.float32
42
  else:
@@ -110,11 +119,25 @@ class SmolLM3Model:
110
  # If flash attention is not supported, skip it
111
  pass
112
 
113
- self.model = AutoModelForCausalLM.from_pretrained(
114
- self.model_name,
115
- config=model_config,
116
- **model_kwargs
117
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  # Enable gradient checkpointing if specified
120
  if self.config and self.config.use_gradient_checkpointing:
 
36
  # Set device and dtype
37
  if torch_dtype is None:
38
  if torch.cuda.is_available():
39
+ # Check if config specifies mixed precision
40
+ if config and hasattr(config, 'fp16') and config.fp16:
41
+ # Use fp16 if explicitly configured
42
+ self.torch_dtype = torch.float16
43
+ elif config and hasattr(config, 'bf16') and config.bf16:
44
+ # Use bf16 if explicitly configured
45
+ self.torch_dtype = torch.bfloat16
46
+ else:
47
+ # Default to bf16 for better compatibility
48
+ self.torch_dtype = torch.bfloat16
49
  else:
50
  self.torch_dtype = torch.float32
51
  else:
 
119
  # If flash attention is not supported, skip it
120
  pass
121
 
122
+ # Try to load the model; fall back to fp16 if bf16 fails
123
+ try:
124
+ self.model = AutoModelForCausalLM.from_pretrained(
125
+ self.model_name,
126
+ config=model_config,
127
+ **model_kwargs
128
+ )
129
+ except RuntimeError as e:
130
+ if "bfloat16" in str(e) or "BFloat16" in str(e):
131
+ logger.warning("BFloat16 not supported, falling back to Float16")
132
+ model_kwargs["torch_dtype"] = torch.float16
133
+ self.torch_dtype = torch.float16
134
+ self.model = AutoModelForCausalLM.from_pretrained(
135
+ self.model_name,
136
+ config=model_config,
137
+ **model_kwargs
138
+ )
139
+ else:
140
+ raise
141
 
142
  # Enable gradient checkpointing if specified
143
  if self.config and self.config.use_gradient_checkpointing:
test_mixed_precision.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test script to verify mixed precision configuration
4
+ """
5
+
6
+ import torch
7
+ import sys
8
+ import os
9
+
10
+ # Add project root to path
11
+ project_root = os.path.dirname(os.path.abspath(__file__))
12
+ sys.path.insert(0, project_root)
13
+
14
+ def test_mixed_precision():
15
+ """Test mixed precision configuration"""
16
+ print("Testing mixed precision configuration...")
17
+
18
+ # Test 1: Check CUDA availability
19
+ print(f"CUDA available: {torch.cuda.is_available()}")
20
+ if torch.cuda.is_available():
21
+ print(f"CUDA device count: {torch.cuda.device_count()}")
22
+ print(f"CUDA device capability: {torch.cuda.get_device_capability()}")
23
+ print(f"Current device: {torch.cuda.current_device()}")
24
+
25
+ # Test 2: Test model loading with different dtypes
26
+ try:
27
+ from src.model import SmolLM3Model
28
+ from config.train_smollm3_h100_lightweight import SmolLM3ConfigH100Lightweight
29
+
30
+ config = SmolLM3ConfigH100Lightweight()
31
+ print(f"Config fp16: {config.fp16}")
32
+ print(f"Config bf16: {config.bf16}")
33
+
34
+ # Test model loading
35
+ model = SmolLM3Model(
36
+ model_name="HuggingFaceTB/SmolLM3-3B",
37
+ max_seq_length=4096,
38
+ config=config
39
+ )
40
+
41
+ print(f"Model dtype: {model.torch_dtype}")
42
+ print(f"Model device map: {model.device_map}")
43
+ print("✅ Model loading successful!")
44
+
45
+ # Test training arguments
46
+ training_args = model.get_training_arguments("/tmp/test")
47
+ print(f"Training args fp16: {training_args.fp16}")
48
+ print(f"Training args bf16: {training_args.bf16}")
49
+ print("✅ Training arguments created successfully!")
50
+
51
+ except Exception as e:
52
+ print(f"❌ Error: {e}")
53
+ return False
54
+
55
+ return True
56
+
57
+ if __name__ == "__main__":
58
+ success = test_mixed_precision()
59
+ if success:
60
+ print("\n🎉 Mixed precision test passed!")
61
+ else:
62
+ print("\n❌ Mixed precision test failed!")
63
+ sys.exit(1)
test_training_fix.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Quick test to verify the training configuration fix
4
+ """
5
+
6
+ import os
7
+ import sys
8
+
9
+ # Add project root to path
10
+ project_root = os.path.dirname(os.path.abspath(__file__))
11
+ sys.path.insert(0, project_root)
12
+
13
+ def test_configuration():
14
+ """Test the H100 lightweight configuration"""
15
+ print("Testing H100 Lightweight Configuration...")
16
+
17
+ try:
18
+ from config.train_smollm3_h100_lightweight import SmolLM3ConfigH100Lightweight
19
+
20
+ config = SmolLM3ConfigH100Lightweight()
21
+
22
+ print("✅ Configuration loaded successfully")
23
+ print(f" Model: {config.model_name}")
24
+ print(f" Batch size: {config.batch_size}")
25
+ print(f" Learning rate: {config.learning_rate}")
26
+ print(f" FP16: {config.fp16}")
27
+ print(f" BF16: {config.bf16}")
28
+ print(f" Mixed precision: {'fp16' if config.fp16 else 'bf16'}")
29
+ print(f" Sample size: {config.sample_size}")
30
+
31
+ # Test training arguments creation
32
+ from src.model import SmolLM3Model
33
+
34
+ # Create a minimal model instance for testing
35
+ model = SmolLM3Model(
36
+ model_name="HuggingFaceTB/SmolLM3-3B",
37
+ max_seq_length=4096,
38
+ config=config
39
+ )
40
+
41
+ # Test training arguments
42
+ training_args = model.get_training_arguments("/tmp/test")
43
+ print(f"✅ Training arguments created successfully")
44
+ print(f" Training args FP16: {training_args.fp16}")
45
+ print(f" Training args BF16: {training_args.bf16}")
46
+
47
+ return True
48
+
49
+ except Exception as e:
50
+ print(f"❌ Error: {e}")
51
+ import traceback
52
+ traceback.print_exc()
53
+ return False
54
+
55
+ if __name__ == "__main__":
56
+ success = test_configuration()
57
+ if success:
58
+ print("\n🎉 Configuration test passed!")
59
+ print("You can now run the training with: ./launch.sh")
60
+ else:
61
+ print("\n❌ Configuration test failed!")
62
+ sys.exit(1)