import torch
import torch.nn as nn
import torch.nn.functional as F
import config


class ContrastiveLoss_simcse(nn.Module):
    """SimCSE-style contrastive loss over label-selected rows.

    Rows of ``feature_vectors`` are split by ``labels`` into anchors (0),
    positives (1) and negatives (2). For every anchor/positive pair the loss
    is ``-log(exp(s_ap / T) / Z)``, where ``s`` is the dot product of the
    normalized rows, ``T`` is the temperature, and ``Z`` is the sum of
    ``exp(s / T)`` over every anchor paired with every positive and negative.

    Args:
        temperature: Softmax temperature. When ``None`` (the default) the
            value is read from ``config.temperature``.
    """

    def __init__(self, temperature=None):
        super().__init__()
        # Fall back to the project-wide setting when no override is supplied.
        self.temperature = (
            config.temperature if temperature is None else temperature
        )

    def forward(self, feature_vectors, labels):
        """Compute the loss.

        Args:
            feature_vectors: Tensor whose rows are selected by ``labels``
                (assumed 2-D, ``(num_samples, feature_dim)`` -- TODO confirm).
            labels: 1-D tensor aligned with the rows of ``feature_vectors``;
                0 marks anchors, 1 positives, 2 negatives.

        Returns:
            Tensor of shape ``(num_anchors, num_positives)`` holding one loss
            value per anchor/positive pair (no reduction is applied).
        """
        # NOTE(review): dim=0 rescales each feature column across the batch
        # rather than each sample vector, so (if rows are samples) the dot
        # products below are not cosine similarities. Kept as-is to preserve
        # existing behaviour -- confirm whether dim=1 was intended.
        normalized_features = F.normalize(feature_vectors, p=2, dim=0)

        anchor = normalized_features[labels == 0]
        positives = normalized_features[labels == 1]
        negatives = normalized_features[labels == 2]
        pos_and_neg = torch.cat([positives, negatives])

        # Temperature-scaled similarities of every anchor to every candidate.
        logits = torch.div(
            torch.matmul(anchor, torch.transpose(pos_and_neg, 0, 1)),
            self.temperature,
        )
        # Positives occupy the leading columns because they were concatenated
        # first, so no second matmul is needed.
        pos_logits = logits[:, : positives.shape[0]]

        # log(sum(exp(logits))) over all entries, computed with logsumexp so
        # that small temperatures do not overflow exp() and produce NaN.
        log_denominator = torch.logsumexp(logits.reshape(-1), dim=0)

        # -log(exp(pos) / denominator) == log(denominator) - pos
        return log_denominator - pos_logits


class ContrastiveLoss_simcse_w(nn.Module):
    """SimCSE-style contrastive loss with per-sample weighting.

    Rows of ``feature_vectors`` are split by ``labels`` into anchors (0),
    positives (1) and negatives (2). Each positive/negative similarity in the
    denominator is multiplied by a weight: the raw score for positives, and
    ``1 + (L2-normalized score)`` for negatives. The numerator is unweighted.

    Args:
        temperature: Softmax temperature. When ``None`` (the default) the
            value is read from ``config.temperature``.
    """

    def __init__(self, temperature=None):
        super().__init__()
        # Fall back to the project-wide setting when no override is supplied.
        self.temperature = (
            config.temperature if temperature is None else temperature
        )

    def forward(self, feature_vectors, labels, scores):
        """Compute the weighted loss.

        Args:
            feature_vectors: Tensor whose rows are selected by ``labels``
                (assumed 2-D, ``(num_samples, feature_dim)`` -- TODO confirm).
            labels: 1-D tensor aligned with the rows of ``feature_vectors``;
                0 marks anchors, 1 positives, 2 negatives.
            scores: 1-D tensor aligned with ``labels``; entries at anchor
                positions are ignored. The input tensor is not modified.

        Returns:
            Tensor of shape ``(num_anchors, num_positives)`` holding one loss
            value per anchor/positive pair (no reduction is applied).
        """
        # NOTE(review): dim=0 rescales each feature column across the batch
        # rather than each sample vector, so (if rows are samples) the dot
        # products below are not cosine similarities. Kept as-is to preserve
        # existing behaviour -- confirm whether dim=1 was intended.
        normalized_features = F.normalize(feature_vectors, p=2, dim=0)

        is_positive = labels == 1
        is_negative = labels == 2

        # Weights: positives keep their raw score; negatives use the
        # L2-normalized score shifted by one.
        pos_weights = scores[is_positive].float()
        neg_weights = F.normalize(scores[is_negative].float(), p=2, dim=0) + 1
        weights = torch.cat([pos_weights, neg_weights])

        anchor = normalized_features[labels == 0]
        positives = normalized_features[is_positive]
        negatives = normalized_features[is_negative]
        pos_and_neg = torch.cat([positives, negatives])

        # Temperature-scaled similarities of every anchor to every candidate.
        logits = torch.div(
            torch.matmul(anchor, torch.transpose(pos_and_neg, 0, 1)),
            self.temperature,
        )
        # Positives occupy the leading columns because they were concatenated
        # first; the numerator uses them without weighting.
        pos_logits = logits[:, : positives.shape[0]]

        # log(sum(exp(weights * logits))) over all entries, computed with
        # logsumexp so small temperatures do not overflow exp() into NaN.
        log_denominator = torch.logsumexp((weights * logits).reshape(-1), dim=0)

        # -log(exp(pos) / denominator) == log(denominator) - pos
        return log_denominator - pos_logits


class ContrastiveLoss_samp(nn.Module):
    """Supervised contrastive loss without weighting.

    Rows of ``feature_vectors`` are split by ``labels`` into anchors (0),
    positives (1) and negatives (2). The loss is the negative mean, over the
    positives, of ``log(exp(s_ap / T) / Z)``, where ``s`` is the dot product
    of the normalized rows, ``T`` is the temperature, and ``Z`` is the sum of
    ``exp(s / T)`` over every anchor paired with every positive and negative.

    Args:
        temperature: Softmax temperature. When ``None`` (the default) the
            value is read from ``config.temperature``.
    """

    def __init__(self, temperature=None):
        super().__init__()
        # Fall back to the project-wide setting when no override is supplied.
        self.temperature = (
            config.temperature if temperature is None else temperature
        )

    def forward(self, feature_vectors, labels):
        """Compute the loss.

        Args:
            feature_vectors: Tensor whose rows are selected by ``labels``
                (assumed 2-D, ``(num_samples, feature_dim)`` -- TODO confirm).
            labels: 1-D tensor aligned with the rows of ``feature_vectors``;
                0 marks anchors, 1 positives, 2 negatives.

        Returns:
            Scalar (0-dim) tensor with the loss.

        Raises:
            ZeroDivisionError: If ``labels`` contains no positives.
        """
        # NOTE(review): dim=0 rescales each feature column across the batch
        # rather than each sample vector, so (if rows are samples) the dot
        # products below are not cosine similarities. Kept as-is to preserve
        # existing behaviour -- confirm whether dim=1 was intended.
        normalized_features = F.normalize(feature_vectors, p=2, dim=0)

        anchor = normalized_features[labels == 0]
        positives = normalized_features[labels == 1]
        negatives = normalized_features[labels == 2]
        pos_and_neg = torch.cat([positives, negatives])

        pos_cardinal = positives.shape[0]

        # Temperature-scaled similarities of every anchor to every candidate.
        logits = torch.div(
            torch.matmul(anchor, torch.transpose(pos_and_neg, 0, 1)),
            self.temperature,
        )
        # Positives occupy the leading columns because they were concatenated
        # first, so no second matmul is needed.
        pos_logits = logits[:, :pos_cardinal]

        # log(sum(exp(logits))) over all entries, computed with logsumexp so
        # that small temperatures do not overflow exp() and produce NaN.
        log_denominator = torch.logsumexp(logits.reshape(-1), dim=0)

        # log(exp(pos) / denominator) == pos - log(denominator)
        sum_log_ent = torch.sum(pos_logits - log_denominator)

        scale = -1 / pos_cardinal

        return scale * sum_log_ent


class ContrastiveLoss_samp_w(nn.Module):
    """Supervised contrastive loss with per-sample weighting.

    Rows of ``feature_vectors`` are split by ``labels`` into anchors (0),
    positives (1) and negatives (2). Each positive/negative similarity in the
    denominator is multiplied by a weight: the raw score for positives, and
    ``1 + (L2-normalized score)`` for negatives. The numerator is unweighted,
    and the summed log-ratio is averaged over the number of positives.

    Args:
        temperature: Softmax temperature. When ``None`` (the default) the
            value is read from ``config.temperature``.
    """

    def __init__(self, temperature=None):
        super().__init__()
        # Fall back to the project-wide setting when no override is supplied.
        self.temperature = (
            config.temperature if temperature is None else temperature
        )

    def forward(self, feature_vectors, labels, scores):
        """Compute the weighted loss.

        Args:
            feature_vectors: Tensor whose rows are selected by ``labels``
                (assumed 2-D, ``(num_samples, feature_dim)`` -- TODO confirm).
            labels: 1-D tensor aligned with the rows of ``feature_vectors``;
                0 marks anchors, 1 positives, 2 negatives.
            scores: 1-D tensor aligned with ``labels``; entries at anchor
                positions are ignored. The input tensor is not modified.

        Returns:
            Scalar (0-dim) tensor with the loss.

        Raises:
            ZeroDivisionError: If ``labels`` contains no positives.
        """
        # NOTE(review): dim=0 rescales each feature column across the batch
        # rather than each sample vector, so (if rows are samples) the dot
        # products below are not cosine similarities. Kept as-is to preserve
        # existing behaviour -- confirm whether dim=1 was intended.
        normalized_features = F.normalize(feature_vectors, p=2, dim=0)

        is_positive = labels == 1
        is_negative = labels == 2

        # Pick scores with the same label masks used for the features so the
        # weights always line up with the similarity columns. Positional
        # slicing (scores[:n_pos], scores[n_pos + 1:]) only matched a
        # [positives, anchor, negatives] layout and picked up the anchor's
        # score when the anchor came first.
        pos_weights = scores[is_positive].float()
        neg_weights = F.normalize(scores[is_negative].float(), p=2, dim=0) + 1
        weights = torch.cat([pos_weights, neg_weights])

        anchor = normalized_features[labels == 0]
        positives = normalized_features[is_positive]
        negatives = normalized_features[is_negative]
        pos_and_neg = torch.cat([positives, negatives])

        pos_cardinal = positives.shape[0]

        # Temperature-scaled similarities of every anchor to every candidate.
        logits = torch.div(
            torch.matmul(anchor, torch.transpose(pos_and_neg, 0, 1)),
            self.temperature,
        )
        # Positives occupy the leading columns because they were concatenated
        # first; the numerator uses them without weighting.
        pos_logits = logits[:, :pos_cardinal]

        # log(sum(exp(weights * logits))) over all entries, computed with
        # logsumexp so small temperatures do not overflow exp() into NaN.
        log_denominator = torch.logsumexp((weights * logits).reshape(-1), dim=0)

        # log(exp(pos) / denominator) == pos - log(denominator)
        sum_log_ent = torch.sum(pos_logits - log_denominator)

        scale = -1 / pos_cardinal

        return scale * sum_log_ent