Tonic committed
Commit fa7de39 · Parent: 7181190

adds quantization configuration correctly

config/train_gpt_oss_memory_optimized.py CHANGED
@@ -89,11 +89,8 @@ class GPTOSSMemoryOptimizedConfig:
 
         if self.quantization_config is None:
             self.quantization_config = {
-                "dequantize": True,
-                "load_in_4bit": True,
-                "bnb_4bit_compute_dtype": "bfloat16",
-                "bnb_4bit_use_double_quant": True,
-                "bnb_4bit_quant_type": "nf4"
+                "dequantize": True,     # Use Mxfp4Config as per tutorial
+                "load_in_4bit": False   # Only use 4-bit if explicitly needed
             }
 
         if self.model_kwargs is None:
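
For reference, a minimal runnable sketch of the resulting default: the class name and the two touched fields come from this hunk, while the bare model_kwargs fallback and the reduced field list are assumptions made only to keep the example self-contained.

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class GPTOSSMemoryOptimizedConfig:
        # Only the fields touched by this hunk are shown; the real class defines more.
        quantization_config: Optional[dict] = None
        model_kwargs: Optional[dict] = None

        def __post_init__(self):
            if self.quantization_config is None:
                self.quantization_config = {
                    "dequantize": True,     # resolved to Mxfp4Config by the loader below
                    "load_in_4bit": False   # only use 4-bit if explicitly requested
                }
            if self.model_kwargs is None:
                self.model_kwargs = {}  # assumption: placeholder for the real defaults

    print(GPTOSSMemoryOptimizedConfig().quantization_config)
    # {'dequantize': True, 'load_in_4bit': False}
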
requirements/requirements_core.txt CHANGED
@@ -1,6 +1,6 @@
 # Core dependencies for SmolLM3 and GPT-OSS fine-tuning
 torch>=2.0.0
-transformers>=4.55.0  # Updated for GPT-OSS compatibility
+transformers @ git+https://github.com/huggingface/transformers.git  # Latest version with GPT-OSS support
 datasets>=2.14.0
 accelerate>=0.20.0
 peft>=0.17.0  # Updated for GPT-OSS LoRA support
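
Since the pin moves from a released version to a git install, a quick sanity check after installing is whether that build actually exposes Mxfp4Config, which the loader change below relies on; a small hedged check:

    import transformers

    print("transformers version:", transformers.__version__)
    try:
        from transformers import Mxfp4Config  # only present in sufficiently recent builds
        print("Mxfp4Config is available")
    except ImportError:
        print("Mxfp4Config not available; the loader below falls back to no quantization")
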
scripts/training/train_gpt_oss.py CHANGED
@@ -27,26 +27,38 @@ def load_gpt_oss_model_and_tokenizer(config):
 
     # Set up quantization config based on config
     if config.quantization_config and config.quantization_config.get("load_in_4bit"):
-        # Use BitsAndBytesConfig for 4-bit quantization
+        # Use BitsAndBytesConfig for 4-bit quantization (memory optimized)
         quantization_config = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_compute_dtype=torch.bfloat16,
             bnb_4bit_use_double_quant=True,
             bnb_4bit_quant_type="nf4"
         )
+    elif config.quantization_config and config.quantization_config.get("dequantize"):
+        # Try to use Mxfp4Config if available (as per tutorial)
+        try:
+            from transformers import Mxfp4Config
+            quantization_config = Mxfp4Config(dequantize=True)
+        except ImportError:
+            # Fall back to no quantization if Mxfp4Config is not available
+            print("Warning: Mxfp4Config not available, using no quantization")
+            quantization_config = None
     else:
-        # Use BitsAndBytesConfig as default (no quantization)
+        # No quantization
        quantization_config = None
 
     # Model kwargs as per tutorial
     model_kwargs = {
         "attn_implementation": "eager",
         "torch_dtype": torch.bfloat16,
-        "quantization_config": quantization_config,
         "use_cache": False,
         "device_map": "auto",
     }
 
+    # Only add quantization_config if it's not None
+    if quantization_config is not None:
+        model_kwargs["quantization_config"] = quantization_config
+
     model = AutoModelForCausalLM.from_pretrained(config.model_name, **model_kwargs)
 
     return model, tokenizer
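
Putting the two changes together, the branching added here can be restated as a standalone helper (the name build_quantization_config is hypothetical and not part of this commit) to show what the new memory-optimized defaults resolve to:

    import torch
    from transformers import BitsAndBytesConfig

    def build_quantization_config(quant_cfg):
        # Same branching as the updated loader, extracted so it can be checked in isolation.
        if quant_cfg and quant_cfg.get("load_in_4bit"):
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
        if quant_cfg and quant_cfg.get("dequantize"):
            try:
                from transformers import Mxfp4Config
                return Mxfp4Config(dequantize=True)
            except ImportError:
                print("Warning: Mxfp4Config not available, using no quantization")
        return None

    # With the new defaults ({"dequantize": True, "load_in_4bit": False}) this returns
    # Mxfp4Config(dequantize=True) when available and None otherwise; the loader only
    # passes quantization_config to from_pretrained when it is not None.
    print(build_quantization_config({"dequantize": True, "load_in_4bit": False}))
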