import torch
import torch.nn as nn
import torch.nn.functional as F


class CrossEntropyLoss(nn.Module):
    """
    Symmetric InfoNCE (cross-entropy) loss for contrastive learning.
    """

    def __init__(self, **kwargs):
        super(CrossEntropyLoss, self).__init__()

    def forward(self, tar_img_feat: torch.Tensor, query_feat: torch.Tensor, temp: float):
        device = tar_img_feat.device

        # Pairwise similarities in both directions, scaled by the temperature.
        sim_t2q = tar_img_feat @ query_feat.T / temp
        sim_q2t = query_feat @ tar_img_feat.T / temp

        # The i-th target matches the i-th query, so the labels are the diagonal indices.
        bs = sim_t2q.size(0)
        labels = torch.arange(bs, device=device)
        loss_t2q = F.cross_entropy(sim_t2q, labels)
        loss_q2t = F.cross_entropy(sim_q2t, labels)

        return (loss_t2q + loss_q2t) / 2


class HardNegativeNCE(nn.Module):
    """
    Hard-Negative NCE loss for contrastive learning.
    https://arxiv.org/pdf/2301.02280.pdf
    """

    def __init__(self, alpha: float = 1.0, beta: float = 0.0, **kwargs):
        """
        Args:
            alpha: rescaling factor for the positive terms
            beta: concentration parameter for the hard-negative weights

        Note: alpha = 1 and beta = 0 recovers the original InfoNCE loss.
        """
        super(HardNegativeNCE, self).__init__()
        self.alpha = alpha
        self.beta = beta

    def forward(
        self,
        video_embds: torch.Tensor,
        text_embds: torch.Tensor,
        temp: float,
    ):
        """
        Args:
            video_embds: (batch_size, video_embd_dim)
            text_embds: (batch_size, text_embd_dim)
            temp: temperature used to scale the similarities
        """
        batch_size = video_embds.size(0)

        # Similarity matrix, scaled by the temperature.
        sim_matrix = video_embds @ text_embds.T  # (batch_size, batch_size)
        sim_matrix = sim_matrix / temp
        sim_matrix = sim_matrix.float()
        numerator = torch.diagonal(sim_matrix)

        # Hard-negative weights: off-diagonal terms are up-weighted according to
        # their similarity, normalized over the negatives of the same anchor.
        beta_sim = self.beta * sim_matrix
        exp_beta_sim = torch.exp(beta_sim)
        exp_beta_diag = torch.exp(torch.diagonal(beta_sim))
        # video-to-text: normalize over the candidate texts for each video (rows)
        w_v2t = (
            (batch_size - 1)
            * exp_beta_sim
            / (exp_beta_sim.sum(dim=1, keepdim=True) - exp_beta_diag.unsqueeze(1))
        )
        # text-to-video: normalize over the candidate videos for each text (columns)
        w_t2v = (
            (batch_size - 1)
            * exp_beta_sim
            / (exp_beta_sim.sum(dim=0, keepdim=True) - exp_beta_diag.unsqueeze(0))
        )
        # replace the diagonal (positive) terms of w_v2t and w_t2v with alpha
        w_v2t[range(batch_size), range(batch_size)] = self.alpha
        w_t2v[range(batch_size), range(batch_size)] = self.alpha

        denominator_v2t = torch.log((torch.exp(sim_matrix) * w_v2t).sum(dim=1))
        denominator_t2v = torch.log((torch.exp(sim_matrix) * w_t2v).sum(dim=0))
        hn_nce_loss = (denominator_v2t - numerator).mean() + (
            denominator_t2v - numerator
        ).mean()

        return hn_nce_loss
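

if __name__ == "__main__":
    # Minimal, illustrative usage sketch (not part of the original module):
    # random L2-normalized embeddings stand in for real video/text features,
    # and the batch size, embedding dim, and temperature are assumed values.
    torch.manual_seed(0)
    bs, dim = 8, 256
    video = F.normalize(torch.randn(bs, dim), dim=-1)
    text = F.normalize(torch.randn(bs, dim), dim=-1)

    info_nce = CrossEntropyLoss()
    hn_nce = HardNegativeNCE(alpha=1.0, beta=0.5)

    print("InfoNCE loss:", info_nce(video, text, temp=0.07).item())
    print("HN-NCE loss:", hn_nce(video, text, temp=0.07).item())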