NohTow commited on
Commit
f1c47e8
·
1 Parent(s): 29f5554

Return BaseModelOutput

Browse files
Files changed (1) hide show
  1. modeling_flexbert.py +4 -3
modeling_flexbert.py CHANGED
@@ -66,7 +66,7 @@ from transformers.modeling_outputs import (
66
  SequenceClassifierOutput,
67
  )
68
  from transformers.models.bert.modeling_bert import BertPreTrainedModel
69
-
70
  from .bert_padding import index_put_first_axis
71
 
72
  from .activation import get_act_fn
@@ -874,7 +874,7 @@ class FlexBertPreTrainedModel(BertPreTrainedModel):
874
  init_weights(self.config, module, type_of_module=ModuleType.emb)
875
  else:
876
  print("Custom weight init for the given module is not supported")
877
- print(module)
878
  # raise NotImplementedError("Custom weight init for the given module is not supported")
879
 
880
 
@@ -967,7 +967,8 @@ class FlexBertModel(FlexBertPreTrainedModel):
967
 
968
  if self.final_norm is not None:
969
  encoder_outputs = self.final_norm(encoder_outputs)
970
- return encoder_outputs
 
971
 
972
  def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
973
  assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
 
66
  SequenceClassifierOutput,
67
  )
68
  from transformers.models.bert.modeling_bert import BertPreTrainedModel
69
+ from transformers import BaseModelOutput
70
  from .bert_padding import index_put_first_axis
71
 
72
  from .activation import get_act_fn
 
874
  init_weights(self.config, module, type_of_module=ModuleType.emb)
875
  else:
876
  print("Custom weight init for the given module is not supported")
877
+ # print(module)
878
  # raise NotImplementedError("Custom weight init for the given module is not supported")
879
 
880
 
 
967
 
968
  if self.final_norm is not None:
969
  encoder_outputs = self.final_norm(encoder_outputs)
970
+
971
+ return BaseModelOutput(last_hidden_state=encoder_outputs)
972
 
973
  def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
974
  assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"