davda54 commited on
Commit
0c90c6d
1 Parent(s): 16a504c

Update to new HF version

Browse files
Files changed (1) hide show
  1. modeling_nort5.py +18 -1
modeling_nort5.py CHANGED
@@ -387,7 +387,24 @@ class NorT5Model(NorT5PreTrainedModel):
387
  self.embedding.word_embedding = value
388
 
389
  def get_encoder(self):
390
- return self.get_encoder_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
391
 
392
  def get_decoder(self):
393
  return self.get_decoder_output
 
387
  self.embedding.word_embedding = value
388
 
389
  def get_encoder(self):
390
+ class EncoderWrapper:
391
+ def __call__(cls, *args, **kwargs):
392
+ return cls.forward(*args, **kwargs)
393
+
394
+ def forward(
395
+ cls,
396
+ input_ids: Optional[torch.Tensor] = None,
397
+ attention_mask: Optional[torch.Tensor] = None,
398
+ output_hidden_states: Optional[bool] = None,
399
+ output_attentions: Optional[bool] = None,
400
+ return_dict: Optional[bool] = None,
401
+ ):
402
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
403
+
404
+ return self.get_encoder_output(
405
+ input_ids, attention_mask, output_hidden_states, output_attentions, return_dict=return_dict
406
+ )
407
+ return EncoderWrapper()
408
 
409
  def get_decoder(self):
410
  return self.get_decoder_output