Update mosaic_gpt.py
#2
by
i-gao
- opened
- mosaic_gpt.py +1 -1
mosaic_gpt.py
CHANGED
@@ -247,7 +247,7 @@ class MosaicGPT(PreTrainedModel):
|
|
247 |
use_cache: Optional[bool] = None):
|
248 |
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
249 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
250 |
-
attention_mask = attention_mask.bool()
|
251 |
|
252 |
# These args are passed in by keyword in huggingface's generate function
|
253 |
# https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/generation/utils.py#L2201-L2206
|
|
|
247 |
use_cache: Optional[bool] = None):
|
248 |
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
249 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
250 |
+
attention_mask = attention_mask.bool() if attention_mask is not None else None
|
251 |
|
252 |
# These args are passed in by keyword in huggingface's generate function
|
253 |
# https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/generation/utils.py#L2201-L2206
|