anas-awadalla
committed on
Commit
•
f0a13e4
1
Parent(s):
38fdb61
turn attention_mask to bool in forward pass
Browse files- mosaic_gpt.py +1 -0
mosaic_gpt.py
CHANGED
@@ -247,6 +247,7 @@ class MosaicGPT(PreTrainedModel):
|
|
247 |
use_cache: Optional[bool] = None):
|
248 |
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
249 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
|
|
250 |
|
251 |
# These args are passed in by keyword in huggingface's generate function
|
252 |
# https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/generation/utils.py#L2201-L2206
|
|
|
247 |
use_cache: Optional[bool] = None):
|
248 |
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
249 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
250 |
+
attention_mask = attention_mask.bool()
|
251 |
|
252 |
# These args are passed in by keyword in huggingface's generate function
|
253 |
# https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/generation/utils.py#L2201-L2206
|