anas-awadalla
commited on
Commit
•
bfa38d4
1
Parent(s):
8bc4eba
Update mosaic_gpt.py
Browse files- mosaic_gpt.py +2 -1
mosaic_gpt.py
CHANGED
@@ -247,7 +247,8 @@ class MosaicGPT(PreTrainedModel):
|
|
247 |
use_cache: Optional[bool] = None):
|
248 |
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
249 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
250 |
-
|
|
|
251 |
# These args are passed in by keyword in huggingface's generate function
|
252 |
# https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/generation/utils.py#L2201-L2206
|
253 |
# but have not yet been fully implemented in MosaicGPT
|
|
|
247 |
use_cache: Optional[bool] = None):
|
248 |
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
249 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
250 |
+
attention_mask = attention_mask.bool()
|
251 |
+
|
252 |
# These args are passed in by keyword in huggingface's generate function
|
253 |
# https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/generation/utils.py#L2201-L2206
|
254 |
# but have not yet been fully implemented in MosaicGPT
|