|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import config |
|
|
|
|
|
class CL_loss(nn.Module):
    """Supervised contrastive loss without weighting.

    Expects a batch of feature vectors plus a ``labels`` tensor tagging each
    row as anchor (0), positive (1), or negative (2). The loss is the mean
    over positives of ``-log( exp(sim(a,p)/T) / sum_k exp(sim(a,k)/T) )``,
    where the denominator sums over all positives and negatives.

    NOTE(review): the similarity computation broadcasts the anchor row
    against every positive/negative row, so the batch is assumed to contain
    exactly one anchor (one row with label 0) — confirm against the caller.
    """

    def __init__(self):
        super(CL_loss, self).__init__()
        # Softmax temperature; smaller values sharpen the distribution.
        self.temperature = config.temperature

    def forward(self, feature_vectors, labels):
        """Compute the contrastive loss for one batch.

        Args:
            feature_vectors: float tensor of shape (batch, dim).
            labels: int tensor of shape (batch,) with values in {0, 1, 2}
                (0 = anchor, 1 = positive, 2 = negative).

        Returns:
            Scalar tensor holding the loss. Undefined (inf/NaN) when the
            batch contains no positives, matching the original behavior.
        """
        # L2-normalize so cosine similarity reduces to a dot product.
        normalized_features = F.normalize(feature_vectors, p=2, dim=1)

        # Partition the batch by role.
        anchor_indices = (labels == 0).nonzero().squeeze(dim=1)
        positive_indices = (labels == 1).nonzero().squeeze(dim=1)
        negative_indices = (labels == 2).nonzero().squeeze(dim=1)

        anchor = normalized_features[anchor_indices]
        positives = normalized_features[positive_indices]
        negatives = normalized_features[negative_indices]
        # Positives come first, so logits[:pos_cardinal] are the positives.
        pos_and_neg = torch.cat([positives, negatives])

        pos_cardinal = positives.shape[0]

        # Temperature-scaled similarities of the anchor to every candidate.
        logits = torch.div(
            F.cosine_similarity(anchor, pos_and_neg, dim=1),
            self.temperature,
        )

        # Work in log space: log(exp(x) / sum_k exp(x_k)) == x - logsumexp.
        # The original exp-then-log formulation overflows exp() for small
        # temperatures; logsumexp is the numerically stable equivalent.
        log_denominator = torch.logsumexp(logits, dim=0)

        # Sum of log-probabilities of the positives (leading slice of
        # `logits` thanks to the concatenation order above).
        sum_log_ent = torch.sum(logits[:pos_cardinal] - log_denominator)

        # Mean negative log-likelihood over the positives.
        return -sum_log_ent / pos_cardinal
|
|