Return BaseModelOutput
Browse files- modeling_flexbert.py +4 -3
modeling_flexbert.py
CHANGED
@@ -66,7 +66,7 @@ from transformers.modeling_outputs import (
|
|
66 |
SequenceClassifierOutput,
|
67 |
)
|
68 |
from transformers.models.bert.modeling_bert import BertPreTrainedModel
|
69 |
-
|
70 |
from .bert_padding import index_put_first_axis
|
71 |
|
72 |
from .activation import get_act_fn
|
@@ -874,7 +874,7 @@ class FlexBertPreTrainedModel(BertPreTrainedModel):
|
|
874 |
init_weights(self.config, module, type_of_module=ModuleType.emb)
|
875 |
else:
|
876 |
print("Custom weight init for the given module is not supported")
|
877 |
-
print(module)
|
878 |
# raise NotImplementedError("Custom weight init for the given module is not supported")
|
879 |
|
880 |
|
@@ -967,7 +967,8 @@ class FlexBertModel(FlexBertPreTrainedModel):
|
|
967 |
|
968 |
if self.final_norm is not None:
|
969 |
encoder_outputs = self.final_norm(encoder_outputs)
|
970 |
-
|
|
|
971 |
|
972 |
def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
|
973 |
assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
|
|
|
66 |
SequenceClassifierOutput,
|
67 |
)
|
68 |
from transformers.models.bert.modeling_bert import BertPreTrainedModel
|
69 |
+
from transformers import BaseModelOutput
|
70 |
from .bert_padding import index_put_first_axis
|
71 |
|
72 |
from .activation import get_act_fn
|
|
|
874 |
init_weights(self.config, module, type_of_module=ModuleType.emb)
|
875 |
else:
|
876 |
print("Custom weight init for the given module is not supported")
|
877 |
+
# print(module)
|
878 |
# raise NotImplementedError("Custom weight init for the given module is not supported")
|
879 |
|
880 |
|
|
|
967 |
|
968 |
if self.final_norm is not None:
|
969 |
encoder_outputs = self.final_norm(encoder_outputs)
|
970 |
+
|
971 |
+
return BaseModelOutput(last_hidden_state=encoder_outputs)
|
972 |
|
973 |
def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
|
974 |
assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
|