File size: 1,018 Bytes
07ea3b8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
import torch
from transformers import XLMRobertaModel as XLMRobertaModelBase
class XLMRobertaModel(XLMRobertaModelBase):
    """XLM-RoBERTa encoder with separate question/answer projection heads.

    Produces fixed-size sentence embeddings (default 512-dim) via
    attention-masked mean pooling over token states, followed by a
    tanh-activated linear projection — one head for questions, one for
    answers, so the two text types live in distinct learned subspaces.
    """

    def __init__(self, config, projection_dim=512):
        super().__init__(config)
        # Read the encoder's actual hidden size from the config instead of
        # hard-coding 768, so non-base checkpoints (e.g. xlm-roberta-large,
        # hidden size 1024) work too. Fall back to 768 for configs that
        # lack the attribute, preserving the original behavior.
        hidden_size = getattr(config, "hidden_size", 768)
        self.question_projection = torch.nn.Linear(hidden_size, projection_dim)
        self.answer_projection = torch.nn.Linear(hidden_size, projection_dim)

    def _embed(self, input_ids, attention_mask, projection):
        """Encode tokens and pool them into one projected embedding per row.

        Args:
            input_ids: token-id tensor, shape (batch, seq_len).
            attention_mask: 0/1 tensor, shape (batch, seq_len); 1 = real token.
            projection: linear layer mapping hidden_size -> projection_dim.

        Returns:
            Tensor of shape (batch, projection_dim) with values in (-1, 1)
            from the final tanh.
        """
        # NOTE(review): calls the base transformer via super().__call__ rather
        # than super().forward — kept as-is to preserve hook behavior.
        outputs = super().__call__(input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]  # last hidden states: (batch, seq_len, hidden)
        # Masked mean pooling: weight each token state by its mask bit so
        # padding positions contribute nothing; the clamp guards against
        # division by zero for an all-padding row.
        mask = attention_mask.unsqueeze(-1).expand(sequence_output.size()).float()
        summed = torch.sum(sequence_output * mask, 1)
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return torch.tanh(projection(summed / counts))

    def question(self, input_ids, attention_mask):
        """Embed question text with the question projection head."""
        return self._embed(input_ids, attention_mask, self.question_projection)

    def answer(self, input_ids, attention_mask):
        """Embed answer text with the answer projection head."""
        return self._embed(input_ids, attention_mask, self.answer_projection)