import torch
import math
import torch.nn.functional as F
import torch.nn as nn


class MyPrint:
    """Print to stdout and mirror the same message to a logger."""

    def __init__(self, logger):
        self.logger = logger

    def pprint(self, *args):
        print(*args)
        log_message = ', '.join(str(arg) for arg in args)
        self.logger.info(log_message)
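

# Usage sketch (illustrative; not part of the original module): MyPrint wraps a
# standard logging.Logger so each message is printed and logged in one call.
def _demo_myprint():
    import logging
    logging.basicConfig(level=logging.INFO)
    mp = MyPrint(logging.getLogger(__name__))
    mp.pprint("epoch", 1, "train loss", 0.42)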


class WarmupCosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler):
    """Linear warmup followed by cosine annealing.

    Intended to be stepped once per optimization step, so `total_steps`
    should be the total number of optimization steps, not epochs.
    """

    def __init__(self, optimizer, total_steps, warmup_ratio=0.1, last_epoch=-1):
        self.total_steps = total_steps
        self.warmup_steps = int(total_steps * warmup_ratio)
        super(WarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            # Linear warmup from base_lr / warmup_steps up to base_lr
            lr_scale = (self.last_epoch + 1) / self.warmup_steps
            return [base_lr * lr_scale for base_lr in self.base_lrs]
        else:
            # Cosine decay from base_lr down to 0 over the remaining steps
            cosine_step = self.last_epoch - self.warmup_steps
            cosine_steps = self.total_steps - self.warmup_steps
            return [
                base_lr * (1 + math.cos(math.pi * cosine_step / cosine_steps)) / 2
                for base_lr in self.base_lrs
            ]
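

# Usage sketch (an illustration; the model and hyperparameters are assumptions):
# the scheduler is stepped once per optimization step, so total_steps should
# equal len(dataloader) * num_epochs in a real training loop.
def _demo_warmup_cosine_lr():
    model = nn.Linear(10, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = WarmupCosineAnnealingLR(optimizer, total_steps=100, warmup_ratio=0.1)
    lrs = []
    for _ in range(100):
        optimizer.step()
        scheduler.step()
        lrs.append(scheduler.get_last_lr()[0])
    # LR rises linearly for the first 10 steps, then follows a cosine decay to 0
    return lrs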


class FocalLoss(nn.Module):
    """Focal loss for class-imbalanced classification.

    Expects `probs` to be post-softmax probabilities of shape
    [batch_size, num_classes] and `targets` to be integer class labels.
    """

    def __init__(self, gamma=2, alpha=0.25):
        super(FocalLoss, self).__init__()
        self.gamma = gamma  # Focusing parameter: down-weights easy examples
        self.alpha = alpha  # Balance parameter: global weight on the loss

    def forward(self, probs, targets, eps=1e-7):
        # Probability assigned to each sample's true class: [batch_size, 1]
        target_probs = probs.gather(dim=-1, index=targets.unsqueeze(-1))
        cross_entropy_loss = -torch.log(target_probs + eps)
        # (1 - p_t)^gamma: near 0 for confident predictions, near 1 for hard ones
        modulating_factor = (1 - target_probs) ** self.gamma
        focal_loss = self.alpha * modulating_factor * cross_entropy_loss
        return focal_loss.mean()
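

# Minimal usage sketch (assumed, for illustration): FocalLoss expects
# probabilities, not logits, so apply F.softmax before calling it.
def _demo_focal_loss():
    logits = torch.randn(8, 5)           # [batch_size, num_classes]
    targets = torch.randint(0, 5, (8,))  # integer class labels
    probs = F.softmax(logits, dim=-1)
    criterion = FocalLoss(gamma=2, alpha=0.25)
    return criterion(probs, targets)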


class BatchTripletLoss(nn.Module):
    def __init__(self, margin=1.0, distance_metric="euclidean"):
        """

        :param margin: Margin for triplet loss

        :param distance_metric: "euclidean" or "cosine"

        """
        super(BatchTripletLoss, self).__init__()
        self.margin = margin
        self.distance_metric = distance_metric

    def _pairwise_distances(self, embeddings):
        """

        Compute pairwise distances between embeddings in the batch.

        """
        if self.distance_metric == "euclidean":
            # Squared pairwise Euclidean distances
            embeddings = F.normalize(embeddings, p=2, dim=1)
            dot_product = torch.matmul(embeddings, embeddings.t())
            square_norm = torch.diag(dot_product)
            distances = square_norm.unsqueeze(1) - 2 * dot_product + square_norm.unsqueeze(0)
            distances = torch.clamp(distances, min=0.0)  # Avoid negative distances
            return torch.sqrt(distances + 1e-12)  # Add epsilon for numerical stability
        elif self.distance_metric == "cosine":
            # Cosine similarity -> distance
            normalized_embeddings = F.normalize(embeddings, p=2, dim=1)
            cosine_similarity = torch.matmul(normalized_embeddings, normalized_embeddings.t())
            distances = 1 - cosine_similarity  # Cosine distance
            return distances
        else:
            raise ValueError(f"Unknown distance metric: {self.distance_metric}")

    def forward(self, embeddings, labels):
        """

        Compute the triplet loss for a batch of embeddings and their corresponding labels.

        :param embeddings: Tensor of shape [batch_size, embedding_dim]

        :param labels: Tensor of shape [batch_size], integer class labels

        """
        # Compute pairwise distances
        distances = self._pairwise_distances(embeddings)

        # Mask for valid triplets (Anchor-Positive and Anchor-Negative pairs)
        labels = labels.unsqueeze(1)
        is_positive = labels.eq(labels.t())  # Positive mask
        is_negative = ~is_positive  # Negative mask

        # For each anchor, find the hardest positive and the hardest negative
        anchor_positive_distances = distances * is_positive.float()  # Zero out negative pairs
        hardest_positive_distances = anchor_positive_distances.max(dim=1)[0]

        anchor_negative_distances = distances + 1e6 * is_positive.float()  # Push positive pairs out of the min
        hardest_negative_distances = anchor_negative_distances.min(dim=1)[0]

        # Compute Triplet Loss
        triplet_loss = F.relu(hardest_positive_distances - hardest_negative_distances + self.margin)
        return triplet_loss.mean()
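

# Usage sketch (illustrative assumption): batch-hard mining needs at least two
# samples per class in the batch so every anchor has a positive besides itself.
def _demo_batch_triplet_loss():
    embeddings = torch.randn(8, 64)
    labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])  # two samples per class
    criterion = BatchTripletLoss(margin=1.0, distance_metric="euclidean")
    return criterion(embeddings, labels)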


class CenterLoss(nn.Module):
    def __init__(self, num_classes, feat_dim, alpha=0.5):
        """

        Center Loss 实现

        :param num_classes: 类别数

        :param feat_dim: 特征维度

        :param alpha: 中心点更新的学习率

        """
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.alpha = alpha

        # 初始化类别中心点 [num_classes, feat_dim]
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))

    def forward(self, features, labels):
        """

        计算 Center Loss

        :param features: 输入特征 [batch_size, feat_dim]

        :param labels: 对应标签 [batch_size]

        :return: Center Loss

        """
        batch_size = features.size(0)
        
        # 获取每个样本的类别中心点 [batch_size, feat_dim]
        centers_batch = self.centers[labels]

        # 计算 Center Loss
        loss = F.mse_loss(features, centers_batch, reduction="mean")

        # 手动更新中心点
        self._update_centers(features, labels)

        return loss

    def _update_centers(self, features, labels):
        """
        Update the class centers.

        :param features: Features of the current batch [batch_size, feat_dim]
        :param labels: Labels of the current batch [batch_size]
        """
        # Disable gradient tracking: the centers are moved manually here,
        # independently of any gradients flowing into self.centers
        with torch.no_grad():
            unique_labels = labels.unique()  # Classes present in this batch
            for label in unique_labels:
                mask = labels == label  # Select the samples of this class
                selected_features = features[mask]  # Features of this class
                if selected_features.size(0) > 0:
                    # Move the center a fraction alpha toward the batch mean
                    center_delta = selected_features.mean(dim=0) - self.centers[label]
                    self.centers[label] = self.centers[label] + self.alpha * center_delta
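

# Usage sketch (an assumption, not from the original file): note that
# self.centers is an nn.Parameter, so it also receives gradients through the
# MSE term; exclude it from the optimizer if only the manual update in
# _update_centers should move the centers.
def _demo_center_loss():
    criterion = CenterLoss(num_classes=10, feat_dim=64, alpha=0.5)
    features = torch.randn(8, 64)
    labels = torch.randint(0, 10, (8,))
    return criterion(features, labels)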


class InterClassLoss(nn.Module):
    """Hinge penalty that pushes every pair of class centers at least `margin` apart."""

    def __init__(self, margin=0.0001):
        super(InterClassLoss, self).__init__()
        self.margin = margin

    def forward(self, centers):
        # Accumulate a hinge term for every pair of centers closer than the margin
        num_classes = centers.size(0)
        loss = 0
        for i in range(num_classes):
            for j in range(i + 1, num_classes):
                distance = torch.norm(centers[i] - centers[j])
                # clamp stays on the same device/dtype as `centers`; the original
                # torch.max(torch.tensor(0.0), ...) fails for CUDA inputs
                loss += torch.clamp(self.margin - distance, min=0.0)
        return loss
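

# Usage sketch (assumed wiring): InterClassLoss is naturally applied to the
# centers held by a CenterLoss instance to keep the class centers separated.
def _demo_inter_class_loss():
    center_loss = CenterLoss(num_classes=4, feat_dim=16)
    criterion = InterClassLoss(margin=0.0001)
    return criterion(center_loss.centers)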