import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MyPrint:
    """Print to stdout and mirror the same message to a logger."""

    def __init__(self, logger):
        self.logger = logger

    def pprint(self, *args):
        print(*args)
        log_message = ', '.join(str(arg) for arg in args)
        self.logger.info(log_message)


class WarmupCosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, total_steps, warmup_ratio=0.1, last_epoch=-1):
        self.total_steps = total_steps
        # Round to an integer step count; keep at least one warmup step so
        # get_lr never divides by zero.
        self.warmup_steps = max(1, int(total_steps * warmup_ratio))
        super(WarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            # Linear warmup towards each base learning rate
            lr_scale = (self.last_epoch + 1) / self.warmup_steps
            return [base_lr * lr_scale for base_lr in self.base_lrs]
        else:
            # Cosine annealing from base_lr down to 0 over the remaining steps
            cosine_step = self.last_epoch - self.warmup_steps
            cosine_steps = self.total_steps - self.warmup_steps
            return [
                base_lr * (1 + math.cos(math.pi * cosine_step / cosine_steps)) / 2
                for base_lr in self.base_lrs
            ]


class FocalLoss(nn.Module):
    def __init__(self, gamma=2, alpha=0.25):
        super(FocalLoss, self).__init__()
        self.gamma = gamma  # Focusing parameter
        self.alpha = alpha  # Balance parameter

    def forward(self, probs, targets, eps=1e-7):
        """
        :param probs: class probabilities (post-softmax), shape [batch_size, num_classes]
        :param targets: integer class labels, shape [batch_size]
        """
        target_probs = probs.gather(dim=-1, index=targets.unsqueeze(-1))
        cross_entropy_loss = -torch.log(target_probs + eps)
        # Down-weight easy examples (those with high target probability)
        modulating_factor = (1 - target_probs) ** self.gamma
        focal_loss = self.alpha * modulating_factor * cross_entropy_loss
        return focal_loss.mean()
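
# --- Usage sketch (illustrative, not part of the original module) -----------
# A minimal example of driving WarmupCosineAnnealingLR and FocalLoss together.
# The linear model, step count, and tensor shapes below are assumptions made
# for the demo, not values the module prescribes.
def _demo_scheduler_and_focal_loss():
    model = nn.Linear(16, 4)  # hypothetical 4-class classifier
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = WarmupCosineAnnealingLR(optimizer, total_steps=100, warmup_ratio=0.1)
    criterion = FocalLoss(gamma=2, alpha=0.25)
    inputs = torch.randn(8, 16)
    targets = torch.randint(0, 4, (8,))
    for _ in range(100):
        probs = F.softmax(model(inputs), dim=-1)  # FocalLoss expects probabilities
        loss = criterion(probs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()  # advance the LR once per optimizer step
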
class BatchTripletLoss(nn.Module):
    def __init__(self, margin=1.0, distance_metric="euclidean"):
        """
        :param margin: Margin for triplet loss
        :param distance_metric: "euclidean" or "cosine"
        """
        super(BatchTripletLoss, self).__init__()
        self.margin = margin
        self.distance_metric = distance_metric

    def _pairwise_distances(self, embeddings):
        """
        Compute pairwise distances between embeddings in the batch.
        """
        if self.distance_metric == "euclidean":
            # L2-normalize, then compute squared pairwise distances via
            # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2
            embeddings = F.normalize(embeddings, p=2, dim=1)
            dot_product = torch.matmul(embeddings, embeddings.t())
            square_norm = torch.diag(dot_product)
            distances = square_norm.unsqueeze(1) - 2 * dot_product + square_norm.unsqueeze(0)
            distances = torch.clamp(distances, min=0.0)  # Clip small negatives from rounding
            return torch.sqrt(distances + 1e-12)  # Epsilon for numerical stability at 0
        elif self.distance_metric == "cosine":
            # Cosine similarity -> distance
            normalized_embeddings = F.normalize(embeddings, p=2, dim=1)
            cosine_similarity = torch.matmul(normalized_embeddings, normalized_embeddings.t())
            distances = 1 - cosine_similarity  # Cosine distance
            return distances
        else:
            raise ValueError(f"Unknown distance metric: {self.distance_metric}")

    def forward(self, embeddings, labels):
        """
        Compute the batch-hard triplet loss for a batch of embeddings and
        their corresponding labels.
        :param embeddings: Tensor of shape [batch_size, embedding_dim]
        :param labels: Tensor of shape [batch_size], integer class labels
        """
        # Pairwise distances between all embeddings in the batch
        distances = self._pairwise_distances(embeddings)

        # Positive mask: entry (i, j) is True when samples i and j share a label
        labels = labels.unsqueeze(1)
        is_positive = labels.eq(labels.t())

        # For each anchor, the hardest positive is the farthest same-class sample
        anchor_positive_distances = distances * is_positive.float()
        hardest_positive_distances = anchor_positive_distances.max(dim=1)[0]

        # The hardest negative is the closest other-class sample; adding a large
        # constant to positive entries (including self) excludes them from the min
        anchor_negative_distances = distances + 1e6 * is_positive.float()
        hardest_negative_distances = anchor_negative_distances.min(dim=1)[0]

        # Hinge on the hardest positive/negative gap
        triplet_loss = F.relu(hardest_positive_distances - hardest_negative_distances + self.margin)
        return triplet_loss.mean()


class CenterLoss(nn.Module):
    def __init__(self, num_classes, feat_dim, alpha=0.5):
        """
        Center Loss implementation.
        :param num_classes: number of classes
        :param feat_dim: feature dimension
        :param alpha: learning rate for the center updates
        """
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.alpha = alpha
        # Initialize the class centers [num_classes, feat_dim]
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))

    def forward(self, features, labels):
        """
        Compute the Center Loss.
        :param features: input features [batch_size, feat_dim]
        :param labels: corresponding labels [batch_size]
        :return: Center Loss
        """
        # Look up each sample's class center [batch_size, feat_dim]
        centers_batch = self.centers[labels]
        # Penalize the distance between features and their class centers
        loss = F.mse_loss(features, centers_batch, reduction="mean")
        # Manually update the centers
        self._update_centers(features, labels)
        return loss

    def _update_centers(self, features, labels):
        """
        Update the class centers.
        :param features: features of the current batch [batch_size, feat_dim]
        :param labels: labels of the current batch [batch_size]
        """
        # Disable gradient tracking for the manual update
        with torch.no_grad():
            unique_labels = labels.unique()  # classes present in the current batch
            for label in unique_labels:
                mask = labels == label  # select the samples of this class
                selected_features = features[mask]  # features belonging to this class
                if selected_features.size(0) > 0:
                    # Move the center towards the batch mean of this class
                    center_delta = selected_features.mean(dim=0) - self.centers[label]
                    self.centers[label] = self.centers[label] + self.alpha * center_delta


class InterClassLoss(nn.Module):
    def __init__(self, margin=0.0001):
        super(InterClassLoss, self).__init__()
        self.margin = margin

    def forward(self, centers):
        # Hinge on the distance between every pair of class centers
        num_classes = centers.size(0)
        loss = centers.new_zeros(())
        for i in range(num_classes):
            for j in range(i + 1, num_classes):
                distance = torch.norm(centers[i] - centers[j])
                # clamp keeps the hinge on the same device/dtype as the centers
                loss = loss + torch.clamp(self.margin - distance, min=0.0)
        return loss
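
# --- Usage sketch (illustrative, not part of the original module) -----------
# A minimal example combining the three embedding losses. The batch size,
# embedding dimension, class count, and the unweighted sum are assumptions
# for the demo; in practice the terms would typically carry tuned weights.
def _demo_embedding_losses():
    embeddings = torch.randn(8, 32)
    labels = torch.randint(0, 4, (8,))
    triplet = BatchTripletLoss(margin=1.0, distance_metric="euclidean")
    center = CenterLoss(num_classes=4, feat_dim=32, alpha=0.5)
    inter = InterClassLoss(margin=0.0001)
    # Pull same-class features together, push class centers apart
    loss = triplet(embeddings, labels) + center(embeddings, labels) + inter(center.centers)
    print(loss.item())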