Tonic committed on
Commit 07eab17 · verified · 1 Parent(s): 829d8f4

removes flash attention 2

Files changed (1)
  1. model.py +19 -5
model.py CHANGED
@@ -86,14 +86,28 @@ class SmolLM3Model:
         model_config.max_position_embeddings = self.max_seq_length
 
         # Load model
+        model_kwargs = {
+            "torch_dtype": self.torch_dtype,
+            "device_map": self.device_map,
+            "trust_remote_code": True,
+            "use_cache": False  # Disable KV cache for training
+        }
+
+        # Only add flash attention if the model supports it
+        if hasattr(self.config, 'use_flash_attention') and self.config.use_flash_attention:
+            try:
+                # Test if the model supports flash attention
+                test_config = AutoConfig.from_pretrained(self.model_name, trust_remote_code=True)
+                if hasattr(test_config, 'use_flash_attention_2'):
+                    model_kwargs["use_flash_attention_2"] = True
+            except:
+                # If flash attention is not supported, skip it
+                pass
+
         self.model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
             config=model_config,
-            torch_dtype=self.torch_dtype,
-            device_map=self.device_map,
-            trust_remote_code=True,
-            use_flash_attention_2=self.config.use_flash_attention if self.config else True,
-            use_cache=False  # Disable KV cache for training
+            **model_kwargs
         )
 
         # Enable gradient checkpointing if specified
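
For context, a minimal sketch of the same conditional-loading idea, assuming a recent transformers release where flash attention is selected via attn_implementation="flash_attention_2"; the helper name load_causal_lm and the flash_attn import check are illustrative assumptions, not code from this repository:

# Minimal sketch, not the repository's code: load_causal_lm and the flash_attn
# import probe are assumptions for illustration.
import torch
from transformers import AutoConfig, AutoModelForCausalLM

def load_causal_lm(model_name, use_flash_attention=False, max_seq_length=4096):
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    config.max_position_embeddings = max_seq_length

    model_kwargs = {
        "torch_dtype": torch.bfloat16,
        "device_map": "auto",
        "trust_remote_code": True,
        "use_cache": False,  # disable KV cache for training
    }

    if use_flash_attention:
        try:
            import flash_attn  # noqa: F401 -- only opt in when the package is installed
            # Recent transformers versions enable FA2 through attn_implementation
            model_kwargs["attn_implementation"] = "flash_attention_2"
        except ImportError:
            pass  # fall back to the default attention implementation

    return AutoModelForCausalLM.from_pretrained(model_name, config=config, **model_kwargs)

Collecting the keyword arguments in a dict mirrors the commit's design: the flash-attention flag can simply be omitted when it is unsupported, instead of being passed unconditionally to from_pretrained.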