# M3Site/utils/util_classes.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
class MyPrint:
    """Print to stdout and mirror the same message to a logger."""

    def __init__(self, logger):
        self.logger = logger

    def pprint(self, *args):
        print(*args)
        # Join all arguments into a single log record.
        log_message = ', '.join(str(arg) for arg in args)
        self.logger.info(log_message)
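# Minimal usage sketch (the logging setup below is an assumption, not part of
# the original module): MyPrint mirrors console output into a Python logger.
def _demo_myprint():
    import logging
    logging.basicConfig(level=logging.INFO)
    printer = MyPrint(logging.getLogger(__name__))
    printer.pprint("epoch", 1, "loss", 0.42)  # prints and logs "epoch, 1, loss, 0.42"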
class WarmupCosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler):
    """Linear warmup followed by cosine annealing, measured in optimizer steps."""

    def __init__(self, optimizer, total_steps, warmup_ratio=0.1, last_epoch=-1):
        self.total_steps = total_steps
        # Integer warmup length; at least 1 step to avoid division by zero below.
        self.warmup_steps = max(1, int(total_steps * warmup_ratio))
        super(WarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            # Linear warmup from 0 up to each base learning rate.
            lr_scale = (self.last_epoch + 1) / self.warmup_steps
            return [base_lr * lr_scale for base_lr in self.base_lrs]
        else:
            # Cosine decay from the base learning rate down to 0.
            cosine_step = self.last_epoch - self.warmup_steps
            cosine_steps = self.total_steps - self.warmup_steps
            return [
                base_lr * (1 + math.cos(math.pi * cosine_step / cosine_steps)) / 2
                for base_lr in self.base_lrs
            ]
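# Minimal usage sketch (the model and optimizer are placeholders, not part of
# the original module): the scheduler is stepped once per optimizer step, so
# total_steps counts steps, not epochs.
def _demo_warmup_cosine_lr():
    model = nn.Linear(8, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = WarmupCosineAnnealingLR(optimizer, total_steps=1000, warmup_ratio=0.1)
    for _ in range(1000):
        optimizer.step()
        scheduler.step()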
class FocalLoss(nn.Module):
    """Focal loss over class probabilities (expects softmax output, not logits)."""

    def __init__(self, gamma=2, alpha=0.25):
        super(FocalLoss, self).__init__()
        self.gamma = gamma  # Focusing parameter: down-weights easy examples
        self.alpha = alpha  # Balance parameter (applied uniformly here)

    def forward(self, probs, targets, eps=1e-7):
        """
        :param probs: class probabilities, shape [batch_size, num_classes]
        :param targets: integer class labels, shape [batch_size]
        """
        # Probability assigned to the target class of each sample.
        target_probs = probs.gather(dim=-1, index=targets.unsqueeze(-1))
        cross_entropy_loss = -torch.log(target_probs + eps)
        # (1 - p_t)^gamma: small for well-classified samples, large for hard ones.
        modulating_factor = (1 - target_probs) ** self.gamma
        focal_loss = self.alpha * modulating_factor * cross_entropy_loss
        return focal_loss.mean()
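# Minimal usage sketch (random tensors are placeholders): FocalLoss consumes
# probabilities, so raw logits are passed through softmax first.
def _demo_focal_loss():
    logits = torch.randn(4, 10)
    targets = torch.randint(0, 10, (4,))
    probs = F.softmax(logits, dim=-1)
    return FocalLoss(gamma=2, alpha=0.25)(probs, targets)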
class BatchTripletLoss(nn.Module):
def __init__(self, margin=1.0, distance_metric="euclidean"):
"""
:param margin: Margin for triplet loss
:param distance_metric: "euclidean" or "cosine"
"""
super(BatchTripletLoss, self).__init__()
self.margin = margin
self.distance_metric = distance_metric
    def _pairwise_distances(self, embeddings):
        """
        Compute pairwise distances between all embeddings in the batch.
        """
        if self.distance_metric == "euclidean":
            # Euclidean distances between L2-normalized embeddings, via
            # ||a - b||^2 = ||a||^2 - 2<a, b> + ||b||^2.
            embeddings = F.normalize(embeddings, p=2, dim=1)
            dot_product = torch.matmul(embeddings, embeddings.t())
            square_norm = torch.diag(dot_product)
            distances = square_norm.unsqueeze(1) - 2 * dot_product + square_norm.unsqueeze(0)
            distances = torch.clamp(distances, min=0.0)  # Avoid small negatives from rounding
            return torch.sqrt(distances + 1e-12)  # Epsilon keeps sqrt differentiable at 0
        elif self.distance_metric == "cosine":
            # Cosine distance = 1 - cosine similarity.
            normalized_embeddings = F.normalize(embeddings, p=2, dim=1)
            cosine_similarity = torch.matmul(normalized_embeddings, normalized_embeddings.t())
            return 1 - cosine_similarity
        else:
            raise ValueError(f"Unknown distance metric: {self.distance_metric}")
    def forward(self, embeddings, labels):
        """
        Compute the batch-hard triplet loss for a batch of embeddings and labels.
        :param embeddings: Tensor of shape [batch_size, embedding_dim]
        :param labels: Tensor of shape [batch_size], integer class labels
        """
        distances = self._pairwise_distances(embeddings)
        # Same-label pairs are positives; everything else is a negative.
        labels = labels.unsqueeze(1)
        is_positive = labels.eq(labels.t())
        # Hardest positive: the farthest same-class sample for each anchor.
        anchor_positive_distances = distances * is_positive.float()
        hardest_positive_distances = anchor_positive_distances.max(dim=1)[0]
        # Hardest negative: the closest other-class sample; positives (including
        # the anchor itself) are pushed out of the min by a large offset.
        anchor_negative_distances = distances + 1e6 * is_positive.float()
        hardest_negative_distances = anchor_negative_distances.min(dim=1)[0]
        # Hinge on the hardest positive/negative gap.
        triplet_loss = F.relu(hardest_positive_distances - hardest_negative_distances + self.margin)
        return triplet_loss.mean()
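# Minimal usage sketch (random embeddings/labels are placeholders): each batch
# should contain several samples per class so that positives exist.
def _demo_batch_triplet_loss():
    embeddings = torch.randn(16, 128)
    labels = torch.randint(0, 4, (16,))
    criterion = BatchTripletLoss(margin=1.0, distance_metric="euclidean")
    return criterion(embeddings, labels)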
class CenterLoss(nn.Module):
    def __init__(self, num_classes, feat_dim, alpha=0.5):
        """
        Center Loss implementation.
        :param num_classes: number of classes
        :param feat_dim: feature dimension
        :param alpha: learning rate for the center updates
        """
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.alpha = alpha
        # Class centers, shape [num_classes, feat_dim]. They are updated
        # manually in _update_centers, so they should not also be handed to
        # the optimizer (that would apply a second, conflicting update).
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))

    def forward(self, features, labels):
        """
        Compute the Center Loss.
        :param features: input features, shape [batch_size, feat_dim]
        :param labels: corresponding labels, shape [batch_size]
        :return: Center Loss
        """
        # Look up each sample's class center, shape [batch_size, feat_dim].
        centers_batch = self.centers[labels]
        loss = F.mse_loss(features, centers_batch, reduction="mean")
        # Update the centers manually with the current batch.
        self._update_centers(features, labels)
        return loss

    def _update_centers(self, features, labels):
        """
        Update the class centers.
        :param features: features of the current batch, shape [batch_size, feat_dim]
        :param labels: labels of the current batch, shape [batch_size]
        """
        # Disable gradient tracking: the update is a manual moving average.
        with torch.no_grad():
            unique_labels = labels.unique()  # Classes present in this batch
            for label in unique_labels:
                mask = labels == label  # Select the samples of this class
                selected_features = features[mask]
                if selected_features.size(0) > 0:
                    # Move the center toward the batch mean of its class.
                    center_delta = selected_features.mean(dim=0) - self.centers[label]
                    self.centers[label] = self.centers[label] + self.alpha * center_delta
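# Minimal usage sketch (random tensors are placeholders): the centers are
# updated inside forward() via the alpha rule, so only the model parameters
# go to the optimizer.
def _demo_center_loss():
    criterion = CenterLoss(num_classes=10, feat_dim=128, alpha=0.5)
    features = torch.randn(32, 128)
    labels = torch.randint(0, 10, (32,))
    return criterion(features, labels)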
class InterClassLoss(nn.Module):
    """Hinge penalty that pushes class centers at least `margin` apart."""

    def __init__(self, margin=0.0001):
        super(InterClassLoss, self).__init__()
        self.margin = margin

    def forward(self, centers):
        # Accumulate the margin violation for every pair of class centers.
        num_classes = centers.size(0)
        loss = 0
        for i in range(num_classes):
            for j in range(i + 1, num_classes):
                distance = torch.norm(centers[i] - centers[j])
                # relu keeps the hinge on the same device/dtype as the centers
                # (torch.max with a fresh CPU tensor breaks on GPU inputs).
                loss += F.relu(self.margin - distance)
        return loss
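# Minimal usage sketch: InterClassLoss is fed the centers maintained by a
# CenterLoss instance and penalizes any pair closer than the margin.
def _demo_inter_class_loss():
    center_loss = CenterLoss(num_classes=10, feat_dim=128)
    return InterClassLoss(margin=0.0001)(center_loss.centers)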