Spaces:
Runtime error
Runtime error
File size: 1,272 Bytes
58627fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import torch.nn as nn
from transformers import ElectraPreTrainedModel, ElectraModel, AutoTokenizer
class ElectraReranker(ElectraPreTrainedModel):
    """
    Shallow wrapper around HuggingFace transformers. All new parameters should be defined at this level.
    This makes sure `{from,save}_pretrained` and `init_weights` are applied to new parameters correctly.

    An ELECTRA encoder followed by a single linear head that maps the hidden
    state at the first token position to one scalar score per sequence.
    """
    # Checkpoint keys matching this pattern are not reported as unexpected
    # when loading weights.
    _keys_to_ignore_on_load_unexpected = [r"cls"]

    # Hub id (or local path) of the tokenizer attached as `raw_tokenizer`.
    # Override in a subclass to pair the model with a different tokenizer.
    tokenizer_name = 'google/electra-large-discriminator'

    def __init__(self, config):
        super().__init__(config)
        self.electra = ElectraModel(config)
        self.linear = nn.Linear(config.hidden_size, 1)
        # NOTE(review): this runs on every construction, including
        # `from_pretrained(path)`, so it loads `tokenizer_name` rather than a
        # tokenizer previously written to `path` by `save`.
        self.raw_tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
        # `post_init` supersedes a bare `init_weights` call in newer
        # transformers releases (it initializes weights and finalizes model
        # setup); fall back to `init_weights` on versions that lack it.
        post_init = getattr(self, 'post_init', None)
        if post_init is not None:
            post_init()
        else:
            self.init_weights()

    def forward(self, encoding):
        """
        Score a batch of tokenized sequences.

        Args:
            encoding: tokenizer output exposing `input_ids` and `attention_mask`
                attributes and, optionally, `token_type_ids` (presumably a
                `BatchEncoding` with PyTorch tensors; verify against callers).

        Returns:
            The linear head's output at token position 0 with the trailing
            singleton dimension squeezed away (one score per sequence).
        """
        # token_type_ids is optional: tokenizers called with
        # return_token_type_ids=False omit it, and ElectraModel accepts None.
        token_type_ids = getattr(encoding, 'token_type_ids', None)
        hidden_states = self.electra(encoding.input_ids,
                                     attention_mask=encoding.attention_mask,
                                     token_type_ids=token_type_ids)[0]
        # Use the hidden state at position 0 of each sequence as its summary.
        scores = self.linear(hidden_states[:, 0]).squeeze(-1)
        return scores

    def save(self, path):
        """
        Write the model (via `save_pretrained`) and its tokenizer to `path`.

        Args:
            path: target location, as a str or os.PathLike.

        Raises:
            ValueError: if `path` ends with '.dnn', a suffix reserved for the
                deprecated checkpoint format.
        """
        # Explicit check instead of `assert`, which `python -O` strips out.
        # str() lets os.PathLike paths through, since save_pretrained accepts them.
        if str(path).endswith('.dnn'):
            raise ValueError(f"{path}: We reserve *.dnn names for the deprecated checkpoint format.")
        self.save_pretrained(path)
        self.raw_tokenizer.save_pretrained(path)