|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import config |
|
|
|
|
|
class ContrastiveLoss_simcse(nn.Module):
    """SimCSE-style contrastive loss.

    Expects a batch of feature vectors of shape (batch, dim) plus one label
    per row: 0 = anchor, 1 = positive, 2 = negative.  Features are
    L2-normalized per sample so dot products are cosine similarities, scaled
    by a temperature read from ``config.temperature``.
    """

    def __init__(self):
        super().__init__()
        # Temperature for scaling similarities before the exponential.
        self.temperature = config.temperature

    def forward(self, feature_vectors, labels):
        """Compute the SimCSE loss terms.

        Args:
            feature_vectors: (batch, dim) float tensor of embeddings.
            labels: (batch,) int tensor; 0 = anchor, 1 = positive,
                2 = negative.

        Returns:
            Tensor of shape (n_anchors, n_positives) holding the elementwise
            -log(numerator / denominator) terms.  NOTE(review): the result is
            not reduced to a scalar — callers are expected to reduce it.
        """
        # BUGFIX: normalize each sample's vector (dim=1), not each feature
        # coordinate across the batch (dim=0).  Per-sample L2 normalization
        # is what makes the matmuls below cosine similarities.
        normalized_features = F.normalize(feature_vectors, p=2, dim=1)

        # Partition the batch rows by role.
        anchor_indices = (labels == 0).nonzero().squeeze(dim=1)
        positive_indices = (labels == 1).nonzero().squeeze(dim=1)
        negative_indices = (labels == 2).nonzero().squeeze(dim=1)

        anchor = normalized_features[anchor_indices]
        positives = normalized_features[positive_indices]
        negatives = normalized_features[negative_indices]
        pos_and_neg = torch.cat([positives, negatives])

        # Scalar: sum of exp(sim / T) over every (anchor, candidate) pair.
        # NOTE(review): this is one denominator shared by all anchors; the
        # textbook formulation sums per anchor — confirm intended.
        denominator = torch.sum(
            torch.exp(
                torch.div(
                    torch.matmul(anchor, torch.transpose(pos_and_neg, 0, 1)),
                    self.temperature,
                )
            )
        )

        # (n_anchors, n_positives): exp(sim / T) per (anchor, positive) pair.
        numerator = torch.exp(
            torch.div(
                torch.matmul(anchor, torch.transpose(positives, 0, 1)),
                self.temperature,
            )
        )

        loss = -torch.log(torch.div(numerator, denominator))

        return loss
|
|
|
|
|
class ContrastiveLoss_simcse_w(nn.Module):
    """SimCSE-style contrastive loss with per-sample weighting.

    Same structure as ``ContrastiveLoss_simcse``, but each candidate's
    similarity in the denominator is multiplied by a weight taken from
    ``scores``: positive scores are used as-is, negative scores are
    L2-normalized as a vector and shifted by +1.
    """

    def __init__(self):
        super().__init__()
        # Temperature for scaling similarities before the exponential.
        self.temperature = config.temperature

    def forward(self, feature_vectors, labels, scores):
        """Compute the weighted SimCSE loss terms.

        Args:
            feature_vectors: (batch, dim) float tensor of embeddings.
            labels: (batch,) int tensor; 0 = anchor, 1 = positive,
                2 = negative.
            scores: (batch,) tensor of per-sample weights, aligned
                positionally with ``labels``.

        Returns:
            Tensor of shape (n_anchors, n_positives) holding the elementwise
            -log(numerator / denominator) terms (not reduced to a scalar).
        """
        # BUGFIX: per-sample L2 normalization (dim=1) so the matmuls below
        # yield cosine similarities; dim=0 normalized each feature
        # coordinate across the batch instead.
        normalized_features = F.normalize(feature_vectors, p=2, dim=1)

        # Partition the batch rows by role.
        anchor_indices = (labels == 0).nonzero().squeeze(dim=1)
        positive_indices = (labels == 1).nonzero().squeeze(dim=1)
        negative_indices = (labels == 2).nonzero().squeeze(dim=1)

        # Weights: positives keep their raw scores; negative scores are
        # rescaled to unit L2 norm (dim=0 is correct here — the slice is a
        # 1-D vector) then shifted by +1.
        pos_scores = scores[positive_indices].float()
        normalized_neg_scores = F.normalize(
            scores[negative_indices].float(), p=2, dim=0
        )
        normalized_neg_scores += 1
        scores = torch.cat([pos_scores, normalized_neg_scores])

        anchor = normalized_features[anchor_indices]
        positives = normalized_features[positive_indices]
        negatives = normalized_features[negative_indices]
        pos_and_neg = torch.cat([positives, negatives])

        # Weighted denominator: sum of exp(score * sim / T) over all pairs.
        # NOTE(review): the numerator below is NOT weighted — confirm this
        # asymmetry is intended.
        denominator = torch.sum(
            torch.exp(
                scores
                * torch.div(
                    torch.matmul(anchor, torch.transpose(pos_and_neg, 0, 1)),
                    self.temperature,
                )
            )
        )

        # (n_anchors, n_positives): exp(sim / T) per (anchor, positive) pair.
        numerator = torch.exp(
            torch.div(
                torch.matmul(anchor, torch.transpose(positives, 0, 1)),
                self.temperature,
            )
        )

        loss = -torch.log(torch.div(numerator, denominator))

        return loss
|
|
|
|
|
class ContrastiveLoss_samp(nn.Module):
    """Supervised contrastive loss without weighting.

    Averages -log(exp(sim(anchor, p) / T) / denominator) over all positives
    p, where the denominator sums exp(sim / T) over both positives and
    negatives.  Similarities are cosine similarities of L2-normalized
    per-sample features.
    """

    def __init__(self):
        super().__init__()
        # Temperature for scaling similarities before the exponential.
        self.temperature = config.temperature

    def forward(self, feature_vectors, labels):
        """Compute the supervised contrastive loss.

        Args:
            feature_vectors: (batch, dim) float tensor of embeddings.
            labels: (batch,) int tensor; 0 = anchor, 1 = positive,
                2 = negative.

        Returns:
            Scalar tensor: the negative mean log-ratio over positives.

        Raises:
            ZeroDivisionError: if the batch contains no positives
                (``pos_cardinal == 0``).
        """
        # BUGFIX: normalize each sample's vector (dim=1), not each feature
        # coordinate across the batch (dim=0) — required so the matmuls
        # below compute cosine similarities.
        normalized_features = F.normalize(feature_vectors, p=2, dim=1)

        # Partition the batch rows by role.
        anchor_indices = (labels == 0).nonzero().squeeze(dim=1)
        positive_indices = (labels == 1).nonzero().squeeze(dim=1)
        negative_indices = (labels == 2).nonzero().squeeze(dim=1)

        anchor = normalized_features[anchor_indices]
        positives = normalized_features[positive_indices]
        negatives = normalized_features[negative_indices]
        pos_and_neg = torch.cat([positives, negatives])

        # |P|: number of positives, used to average the per-positive terms.
        pos_cardinal = positives.shape[0]

        # Scalar: sum of exp(sim / T) over every (anchor, candidate) pair.
        denominator = torch.sum(
            torch.exp(
                torch.div(
                    torch.matmul(anchor, torch.transpose(pos_and_neg, 0, 1)),
                    self.temperature,
                )
            )
        )

        # Sum over positives of log(exp(sim / T) / denominator).
        sum_log_ent = torch.sum(
            torch.log(
                torch.div(
                    torch.exp(
                        torch.div(
                            torch.matmul(
                                anchor, torch.transpose(positives, 0, 1)
                            ),
                            self.temperature,
                        )
                    ),
                    denominator,
                )
            )
        )

        # Negate and average over the positive set.
        scale = -1 / pos_cardinal

        return scale * sum_log_ent
|
|
|
|
|
class ContrastiveLoss_samp_w(nn.Module):
    """Supervised contrastive loss with weighting.

    Like ``ContrastiveLoss_samp`` but the denominator terms are weighted:
    positives keep their raw scores, negatives use their L2-normalized
    scores shifted by +1.
    """

    def __init__(self):
        super().__init__()
        # Temperature for scaling similarities before the exponential.
        self.temperature = config.temperature

    def forward(self, feature_vectors, labels, scores):
        """Compute the weighted supervised contrastive loss.

        Args:
            feature_vectors: (batch, dim) float tensor of embeddings.
            labels: (batch,) int tensor; 0 = anchor, 1 = positive,
                2 = negative.
            scores: (batch,) tensor of per-sample weights (see NOTE below
                about the assumed ordering).

        Returns:
            Scalar tensor: the negative mean log-ratio over positives.

        Raises:
            ZeroDivisionError: if the batch contains no positives.
        """
        # BUGFIX: per-sample L2 normalization (dim=1) so the matmuls below
        # yield cosine similarities; dim=0 normalized each feature
        # coordinate across the batch instead.
        normalized_features = F.normalize(feature_vectors, p=2, dim=1)

        # Partition the batch rows by role.
        anchor_indices = (labels == 0).nonzero().squeeze(dim=1)
        positive_indices = (labels == 1).nonzero().squeeze(dim=1)
        negative_indices = (labels == 2).nonzero().squeeze(dim=1)

        # NOTE(review): unlike ContrastiveLoss_simcse_w, the weights here are
        # sliced positionally — the first len(positive_indices) entries are
        # taken as positive scores, entry len(positive_indices) (presumably
        # the anchor's) is skipped, and the rest are negatives.  This assumes
        # a fixed [positives..., anchor, negatives...] ordering of `scores`
        # that may not match `labels` — confirm against the caller.
        num_skip = len(positive_indices) + 1
        pos_scores = scores[: (num_skip - 1)].float()
        normalized_neg_scores = F.normalize(
            scores[num_skip:].float(), p=2, dim=0
        )
        normalized_neg_scores += 1
        scores = torch.cat([pos_scores, normalized_neg_scores])

        anchor = normalized_features[anchor_indices]
        positives = normalized_features[positive_indices]
        negatives = normalized_features[negative_indices]
        pos_and_neg = torch.cat([positives, negatives])

        # |P|: number of positives, used to average the per-positive terms.
        pos_cardinal = positives.shape[0]

        # Weighted denominator: sum of exp(score * sim / T) over all pairs.
        # NOTE(review): the numerator below is NOT weighted — confirm this
        # asymmetry is intended.
        denominator = torch.sum(
            torch.exp(
                scores
                * torch.div(
                    torch.matmul(anchor, torch.transpose(pos_and_neg, 0, 1)),
                    self.temperature,
                )
            )
        )

        # Sum over positives of log(exp(sim / T) / denominator).
        sum_log_ent = torch.sum(
            torch.log(
                torch.div(
                    torch.exp(
                        torch.div(
                            torch.matmul(
                                anchor, torch.transpose(positives, 0, 1)
                            ),
                            self.temperature,
                        )
                    ),
                    denominator,
                )
            )
        )

        # Negate and average over the positive set.
        scale = -1 / pos_cardinal

        return scale * sum_log_ent
|
|