|
|
|
from transformers import MistralModel, MistralConfig |
|
from typing import Dict |
|
from transformers.file_utils import ModelOutput |
|
from typing import List, Optional, Tuple, Union |
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
|
from torch import nn, Tensor |
|
from dataclasses import dataclass |
|
from torch import nn |
|
from typing import Dict |
|
import torch |
|
from transformers.file_utils import ModelOutput |
|
import torch.nn.functional as F |
|
|
|
COSINE_DISTANCE = lambda x, y: 1-F.cosine_similarity(x, y) |
|
|
|
@dataclass |
|
class EncoderOutput(ModelOutput): |
|
loss: Optional[Tensor] = None |
|
|
|
class MistralModelEmbedding(MistralModel): |
|
def __init__(self, config: MistralConfig, **kwargs): |
|
super().__init__(config, **kwargs) |
|
|
|
self.dense_layer = nn.Linear(self.config.hidden_size,768) |
|
|
|
def sentence_embedding(self, hidden_state, mask): |
|
if self.config.sentence_pooling_method == 'mean': |
|
s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1) |
|
d = mask.sum(axis=1, keepdim=True).float() |
|
return s / d |
|
elif self.config.sentence_pooling_method == 'cls': |
|
return hidden_state[:,0] |
|
|
|
def encode(self, features): |
|
if features is None: |
|
return None |
|
psg_out = super().forward(**features,return_dict=True) |
|
output = self.dense_layer(psg_out.last_hidden_state) |
|
p_reps = self.sentence_embedding(output, features['attention_mask']) |
|
if self.config.normalized: |
|
p_reps = torch.nn.functional.normalize(p_reps, dim=-1) |
|
return p_reps.contiguous() |
|
|
|
|
|
def forward(self, query: Dict[str, Tensor] = None, |
|
passage: Dict[str, Tensor] = None, labels = None, margin = 0.5): |
|
q_reps = self.encode(query) |
|
p_reps = self.encode(passage) |
|
|
|
loss = None |
|
if labels is not None: |
|
distances = COSINE_DISTANCE(q_reps, p_reps) |
|
losses = 0.5 * (labels.float() * distances.pow(2) + (1 - labels).float() * F.relu(margin - distances).pow(2)) |
|
loss = losses.mean() |
|
|
|
return EncoderOutput( |
|
loss=loss, |
|
) |
|
|
|
|