Alex Birch
committed on
apply gradient checkpointing to Attention blocks
modeling_mpt.py +2 -2
modeling_mpt.py CHANGED

@@ -12,7 +12,7 @@ from torch.utils.checkpoint import checkpoint
 from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.utils import logging
-from .attention import attn_bias_shape, build_attn_bias, PastKeyValue
+from .attention import attn_bias_shape, build_attn_bias, PastKeyValue, MultiheadAttention, MultiQueryAttention
 from .blocks import MPTBlock, MPTBlockOutput
 from .norm import NORM_CLASS_REGISTRY
 from .configuration_mpt import MPTConfig
@@ -41,7 +41,7 @@ class MPTPreTrainedModel(PreTrainedModel):
     _no_split_modules = ['MPTBlock']
     supports_gradient_checkpointing = True
     def _set_gradient_checkpointing(self, module: nn.Module, value=False) -> None:
-        if isinstance(module, MPTModel):
+        if isinstance(module, MPTModel) or isinstance(module, MultiheadAttention) or isinstance(module, MultiQueryAttention):
             module.gradient_checkpointing = value
 
 class MPTModel(MPTPreTrainedModel):
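For context on how the flag set here is consumed: in the transformers versions this hook targets, PreTrainedModel.gradient_checkpointing_enable() applies _set_gradient_checkpointing to every submodule, so with this change MultiheadAttention and MultiQueryAttention instances receive gradient_checkpointing = True alongside MPTModel itself. The sketch below is not the repo's actual attention.py; it is a minimal, hypothetical illustration of the standard pattern a submodule can use to honor such a flag by recomputing its activations in the backward pass via torch.utils.checkpoint.

# Hypothetical sketch only -- not the MultiheadAttention / MultiQueryAttention
# defined in .attention. It shows the usual way a module honors a
# gradient_checkpointing flag toggled by _set_gradient_checkpointing.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedSelfAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        # Flipped to True when _set_gradient_checkpointing is applied to this module.
        self.gradient_checkpointing = False

    def _attn_forward(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.attn(x, x, x, need_weights=False)
        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.gradient_checkpointing and self.training:
            # Discard the attention activations after the forward pass and
            # recompute them during backward, trading compute for memory.
            return checkpoint(self._attn_forward, x)
        return self._attn_forward(x)

With the diff above in place, calling gradient_checkpointing_enable() on the MPT model flips the flag on the attention modules as well, rather than on MPTModel alone, so their activations can also be recomputed instead of stored.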