jupyterjazz
committed on
Commit
•
6787a0f
1
Parent(s):
1b0fa28
fix: set fp32 when using cpu because bf16 is slow
Browse files
configuration_xlm_roberta.py
CHANGED
@@ -126,3 +126,5 @@ class XLMRobertaFlashConfig(PretrainedConfig):
|
|
126 |
self.torch_dtype = getattr(torch, torch_dtype)
|
127 |
else:
|
128 |
self.torch_dtype = torch_dtype
|
|
|
|
|
|
126 |
self.torch_dtype = getattr(torch, torch_dtype)
|
127 |
else:
|
128 |
self.torch_dtype = torch_dtype
|
129 |
+
if not self.use_flash_attn or not torch.cuda.is_available():
|
130 |
+
self.torch_dtype = torch.float32
|