|
from colossalai.shardformer.modeling.jit import get_jit_fused_dropout_add_func |
|
from colossalai.shardformer.modeling.t5 import get_jit_fused_T5_layer_ff_forward, get_T5_layer_self_attention_forward |
|
from colossalai.shardformer.policies.base_policy import Policy, SubModuleReplacementDescription |
|
|
|
|
|
class T5EncoderPolicy(Policy):
    """Shardformer policy that swaps layer norms and forwards on T5 modules.

    The policy registers replacements for ``T5LayerFF``,
    ``T5LayerSelfAttention`` and ``T5Stack`` from
    ``transformers.models.t5.modeling_t5``. It leaves the model itself
    untouched in ``preprocess`` and ``postprocess``.
    """

    def config_sanity_check(self):
        """Reject shard configurations this policy does not handle.

        Raises:
            AssertionError: If ``shard_config.enable_tensor_parallelism`` or
                ``shard_config.enable_flash_attention`` is truthy.
        """
        assert (
            not self.shard_config.enable_tensor_parallelism
        ), "T5EncoderPolicy requires shard_config.enable_tensor_parallelism to be disabled"
        assert (
            not self.shard_config.enable_flash_attention
        ), "T5EncoderPolicy requires shard_config.enable_flash_attention to be disabled"

    def preprocess(self):
        """Return ``self.model`` unchanged; no pre-sharding work is done."""
        return self.model

    def module_policy(self):
        """Build the mapping of target module classes to their replacements.

        Returns:
            dict: The policy dictionary filled in through
            ``append_or_create_submodule_replacement`` and
            ``append_or_create_method_replacement``.
        """
        # Imported lazily so that importing this module does not pull in
        # transformers until the policy is actually built.
        from transformers.models.t5.modeling_t5 import T5LayerFF, T5LayerSelfAttention, T5Stack

        policy = {}

        # The custom T5LayerNorm is optional. Only the import itself is
        # guarded: ModuleNotFoundError is a subclass of ImportError, and an
        # ImportError raised while *registering* replacements must not be
        # swallowed, otherwise the policy could end up half populated.
        # NOTE(review): presumably the import fails when an optional fused
        # kernel dependency is missing -- confirm against the opensora module.
        try:
            from opensora.acceleration.shardformer.modeling.t5 import T5LayerNorm
        except ImportError:
            # Best effort: without T5LayerNorm the layer norms are left as is.
            pass
        else:
            # (attribute suffix, owning module class), in registration order.
            layer_norm_targets = (
                ("layer_norm", T5LayerFF),
                ("layer_norm", T5LayerSelfAttention),
                ("final_layer_norm", T5Stack),
            )
            for suffix, target_key in layer_norm_targets:
                self.append_or_create_submodule_replacement(
                    description=SubModuleReplacementDescription(
                        suffix=suffix,
                        target_module=T5LayerNorm,
                    ),
                    policy=policy,
                    target_key=target_key,
                )

        # Optionally swap in the JIT-fused forward / dropout-add methods.
        if self.shard_config.enable_jit_fused:
            self.append_or_create_method_replacement(
                description={
                    "forward": get_jit_fused_T5_layer_ff_forward(),
                    "dropout_add": get_jit_fused_dropout_add_func(),
                },
                policy=policy,
                target_key=T5LayerFF,
            )
            self.append_or_create_method_replacement(
                description={
                    "forward": get_T5_layer_self_attention_forward(),
                    "dropout_add": get_jit_fused_dropout_add_func(),
                },
                policy=policy,
                target_key=T5LayerSelfAttention,
            )

        return policy

    def postprocess(self):
        """Return ``self.model`` unchanged; no post-sharding work is done."""
        return self.model
|
|