File size: 2,280 Bytes
7f4f2d3 d09e211 7f4f2d3 d09e211 7f4f2d3 d09e211 7f4f2d3 d09e211 7f4f2d3 d09e211 7f4f2d3 d09e211 7f4f2d3 d09e211 7f4f2d3 d09e211 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import torch
import torch.nn as nn
import torch.nn.functional as F
import config
class CL_loss(nn.Module):
    """Supervised contrastive loss without per-sample weighting.

    The batch is partitioned by ``labels``: rows labelled 0 are anchors,
    1 are positives, 2 are negatives.  The loss is the standard SupCon /
    InfoNCE form

        L = -(1/|P|) * sum_{p in P} log( exp(s_ap/T) / sum_{k in P∪N} exp(s_ak/T) )

    where ``s_xy`` is the cosine similarity between rows x and y and ``T``
    is the temperature.
    """

    def __init__(self, temperature=None):
        """Create the loss module.

        Args:
            temperature: softmax temperature ``T``.  When ``None`` (the
                default, preserving the original behavior) it is read
                from ``config.temperature``.
        """
        super().__init__()
        self.temperature = (
            config.temperature if temperature is None else temperature
        )

    def forward(self, feature_vectors, labels):
        """Compute the contrastive loss for one batch.

        Args:
            feature_vectors: float tensor of shape ``(batch, dim)``.
            labels: integer tensor of shape ``(batch,)`` with values in
                {0: anchor, 1: positive, 2: negative}.

        Returns:
            Scalar loss tensor.

        Raises:
            ZeroDivisionError: if the batch contains no positive rows.
        """
        # L2-normalize each row so dot products become cosine similarities.
        normalized_features = F.normalize(feature_vectors, p=2, dim=1)

        # Partition the batch by role; squeeze(dim=1) keeps the index
        # tensors 1-D even when only a single row matches.
        anchor = normalized_features[(labels == 0).nonzero().squeeze(dim=1)]
        positives = normalized_features[(labels == 1).nonzero().squeeze(dim=1)]
        negatives = normalized_features[(labels == 2).nonzero().squeeze(dim=1)]

        pos_cardinal = positives.shape[0]
        pos_and_neg = torch.cat([positives, negatives])

        # NOTE(review): cosine_similarity broadcasts `anchor` against the
        # candidate rows, which assumes exactly one anchor row per batch —
        # TODO confirm with the data pipeline.
        all_logits = (
            F.cosine_similarity(anchor, pos_and_neg, dim=1) / self.temperature
        )
        pos_logits = (
            F.cosine_similarity(anchor, positives, dim=1) / self.temperature
        )

        # log( exp(l_p) / sum_k exp(l_k) ) == l_p - logsumexp(l_k).
        # logsumexp avoids the exp() overflow the previous debug prints were
        # chasing (denominator hitting inf at low temperatures).
        log_prob = pos_logits - torch.logsumexp(all_logits, dim=0)

        # Negative mean log-probability over the positives.
        return (-1 / pos_cardinal) * log_prob.sum()
|