File size: 7,167 Bytes
224a33f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
import torch
import math
import torch.nn.functional as F
import torch.nn as nn
class MyPrint():
    """Helper that joins its arguments into one message and logs it."""

    def __init__(self, logger):
        """Keep a reference to the logger used by ``pprint``.

        :param logger: object used for output; assumed to expose an
            ``info(message)`` method like ``logging.Logger`` -- TODO confirm
            against the callers.
        """
        self.logger = logger

    def pprint(self, *args):
        """Join ``args`` with ``', '`` and send the result to the logger.

        Every argument is converted with ``str`` before joining.
        """
        log_message = ', '.join(str(arg) for arg in args)
        # The message used to be built and then dropped; actually emit it.
        self.logger.info(log_message)
class WarmupCosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler):
    """Linear warmup followed by cosine annealing down to zero.

    While ``last_epoch < warmup_steps`` each base learning rate is scaled by
    ``(last_epoch + 1) / warmup_steps`` (capped at 1).  Afterwards the rate
    follows half a cosine period from the base value to zero over the
    remaining ``total_steps - warmup_steps`` steps and then stays at zero.
    """

    def __init__(self, optimizer, total_steps, warmup_ratio=0.1, last_epoch=-1):
        """
        :param optimizer: wrapped optimizer
        :param total_steps: number of scheduler steps covered by the schedule
        :param warmup_ratio: fraction of ``total_steps`` spent warming up
        :param last_epoch: index of the last step, ``-1`` to start fresh
        """
        self.total_steps = total_steps
        # May be fractional (e.g. 15 * 0.1); get_lr copes with that.
        self.warmup_steps = total_steps * warmup_ratio
        super(WarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every param group for ``last_epoch``."""
        if self.last_epoch < self.warmup_steps:
            # Cap at 1 so a fractional warmup_steps cannot push the rate
            # above its base value on the last warmup step.
            lr_scale = min(1.0, (self.last_epoch + 1) / self.warmup_steps)
            return [base_lr * lr_scale for base_lr in self.base_lrs]

        cosine_steps = self.total_steps - self.warmup_steps
        if cosine_steps <= 0:
            # No decay phase at all (warmup_ratio >= 1): treat the schedule
            # as finished instead of dividing by zero.
            angle = math.pi
        else:
            # Clamp so the rate stays at zero past total_steps rather than
            # climbing back up the next half of the cosine.
            cosine_step = min(self.last_epoch - self.warmup_steps, cosine_steps)
            angle = math.pi * cosine_step / cosine_steps
        return [
            base_lr * (1 + math.cos(angle)) / 2
            for base_lr in self.base_lrs
        ]
class FocalLoss(nn.Module):
    """Focal loss evaluated on class probabilities.

    The probability of the target class is read straight from ``probs``; no
    softmax or other normalisation is applied here.
    """

    def __init__(self, gamma=2, alpha=0.25):
        """
        :param gamma: exponent applied to ``1 - p_target`` (focusing term)
        :param alpha: constant weight multiplied into every sample's loss
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, probs, targets, eps=1e-7):
        """Return the mean focal loss.

        :param probs: tensor whose last dimension is indexed by ``targets``
        :param targets: integer index tensor with one fewer dimension
        :param eps: added to the picked probability before the logarithm
        :return: scalar tensor, the mean over all gathered entries
        """
        # Probability assigned to the target class, trailing dim of size 1.
        picked = torch.gather(probs, -1, targets.unsqueeze(-1))
        neg_log_likelihood = torch.log(picked + eps).neg()
        focusing_weight = torch.pow(1 - picked, self.gamma)
        per_sample = self.alpha * focusing_weight * neg_log_likelihood
        return per_sample.mean()
class BatchTripletLoss(nn.Module):
    """Triplet loss using the hardest positive and negative per anchor."""

    def __init__(self, margin=1.0, distance_metric="euclidean"):
        """
        :param margin: Margin for triplet loss
        :param distance_metric: "euclidean" or "cosine"
        """
        super().__init__()
        self.distance_metric = distance_metric
        self.margin = margin

    def _pairwise_distances(self, embeddings):
        """Compute the distance between every pair of rows in the batch.

        Rows are L2-normalised for both metrics.  "euclidean" returns the
        (non-squared) Euclidean distance between the normalised rows;
        "cosine" returns ``1 - cosine_similarity``.

        :raises ValueError: if ``distance_metric`` is neither of the two.
        """
        metric = self.distance_metric
        if metric not in ("euclidean", "cosine"):
            raise ValueError(f"Unknown distance metric: {self.distance_metric}")

        unit_rows = F.normalize(embeddings, p=2, dim=1)
        gram = unit_rows @ unit_rows.t()
        if metric == "cosine":
            return 1 - gram

        # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, with norms on the diagonal.
        sq_norms = torch.diagonal(gram)
        squared = sq_norms.unsqueeze(1) - 2 * gram + sq_norms.unsqueeze(0)
        # Rounding can leave tiny negatives; the epsilon keeps sqrt well
        # behaved at zero.
        squared = squared.clamp(min=0.0)
        return torch.sqrt(squared + 1e-12)

    def forward(self, embeddings, labels):
        """Compute the triplet loss for a batch of embeddings and labels.

        :param embeddings: Tensor of shape [batch_size, embedding_dim]
        :param labels: Tensor of shape [batch_size], integer class labels
        :return: scalar tensor, mean over anchors of
            ``relu(hardest_positive - hardest_negative + margin)``
        """
        dist = self._pairwise_distances(embeddings)

        # same_class[i, j] is 1.0 when rows i and j share a label, including
        # i == j.
        column = labels.unsqueeze(1)
        same_class = (column == column.t()).float()

        # Farthest same-label row: other-label entries are zeroed out.
        hardest_pos = (dist * same_class).max(dim=1).values
        # Closest other-label row: same-label entries are pushed away by 1e6.
        hardest_neg = (dist + 1e6 * same_class).min(dim=1).values

        violation = hardest_pos - hardest_neg + self.margin
        return F.relu(violation).mean()
class CenterLoss(nn.Module):
    """Center loss with per-class centers that are also updated by hand."""

    def __init__(self, num_classes, feat_dim, alpha=0.5):
        """
        :param num_classes: number of classes
        :param feat_dim: feature dimension
        :param alpha: learning rate of the manual center update
        """
        super().__init__()
        self.alpha = alpha
        self.feat_dim = feat_dim
        self.num_classes = num_classes
        # Class centers, randomly initialised: [num_classes, feat_dim]
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))

    def forward(self, features, labels):
        """Compute the center loss, then move the centers.

        :param features: input features [batch_size, feat_dim]
        :param labels: matching labels [batch_size]
        :return: mean squared error between each feature and its class center
        """
        # The loss is taken against the centers as they are before the update.
        loss = F.mse_loss(features, self.centers[labels], reduction="mean")
        self._update_centers(features, labels)
        return loss

    def _update_centers(self, features, labels):
        """Move each class center seen in the batch toward its batch mean.

        :param features: features of the current batch [batch_size, feat_dim]
        :param labels: labels of the current batch [batch_size]
        """
        # Gradient tracking is disabled for the manual update.
        with torch.no_grad():
            for label in labels.unique():
                members = features[labels == label]
                if members.size(0) == 0:
                    continue
                current = self.centers[label]
                shift = members.mean(dim=0) - current
                self.centers[label] = current + self.alpha * shift
class InterClassLoss(nn.Module):
    """Hinge penalty on class centers that lie closer than ``margin``."""

    def __init__(self, margin=0.0001):
        """
        :param margin: minimum distance wanted between any two centers
        """
        super(InterClassLoss, self).__init__()
        self.margin = margin

    def forward(self, centers):
        """Sum ``max(0, margin - distance)`` over every pair of centers.

        :param centers: tensor whose first dimension indexes the classes
        :return: scalar tensor on the same device and dtype as ``centers``
        """
        num_classes = centers.size(0)
        if num_classes < 2:
            # No pairs: still return a tensor so callers can treat the
            # result uniformly (the loop version returned the int 0 here).
            return centers.new_zeros(())

        # Distances between all pairs i < j, computed in one call; trailing
        # dimensions are flattened so each center is a single vector.
        flat = centers.reshape(num_classes, -1)
        distances = F.pdist(flat, p=2)
        # relu avoids building a separate zero tensor that could differ in
        # device or dtype from ``centers``.
        return F.relu(self.margin - distances).sum()