Spaces:
Runtime error
Runtime error
File size: 1,631 Bytes
4d0eb62 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import torch
from mmengine.model import BaseModule
from torch import nn
from mmpretrain.registry import MODELS
@MODELS.register_module()
class CosineSimilarityLoss(BaseModule):
    """Cosine similarity loss function.

    Compute the cosine similarity between two features (along their last
    dimension) and turn it into a loss::

        loss = shift_factor - scale_factor * cos(pred, target)

    Args:
        shift_factor (float): The shift factor of cosine similarity.
            Default: 0.0.
        scale_factor (float): The scale factor of cosine similarity.
            Default: 1.0.
    """

    def __init__(self,
                 shift_factor: float = 0.0,
                 scale_factor: float = 1.0) -> None:
        super().__init__()
        self.shift_factor = shift_factor
        self.scale_factor = scale_factor

    def forward(self,
                pred: torch.Tensor,
                target: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward function of cosine similarity loss.

        Args:
            pred (torch.Tensor): The predicted features. Similarity is
                computed along the last dimension.
            target (torch.Tensor): The target features; must broadcast
                against ``pred``.
            mask (torch.Tensor, optional): Weights applied to the
                per-element loss (whose shape is the broadcast shape of
                ``pred`` and ``target`` without the last dimension). The
                result is ``sum(loss * mask) / sum(mask)``; note the
                denominator sums the mask's own elements, so the mask
                should have the same shape as the per-element loss rather
                than merely broadcast to it. Defaults to None, meaning a
                plain mean over all elements.

        Returns:
            torch.Tensor: The scalar cosine similarity loss. With a mask
            whose entries are all zero, the loss is 0 instead of NaN.
        """
        pred_norm = nn.functional.normalize(pred, dim=-1)
        target_norm = nn.functional.normalize(target, dim=-1)
        # Per-element loss: the dot product of unit vectors is the cosine
        # similarity, and summing over dim=-1 removes the feature axis.
        loss = self.shift_factor - self.scale_factor * (
            pred_norm * target_norm).sum(dim=-1)

        if mask is None:
            return loss.mean()

        if not mask.is_floating_point():
            # bool/int masks: cast so the denominator is a float tensor that
            # can be clamped below (values are unchanged by the cast).
            mask = mask.to(loss.dtype)
        denom = mask.sum()
        # An all-zero mask would otherwise give 0 / 0 = NaN and poison the
        # gradients; clamping to the smallest positive normal keeps every
        # non-degenerate result identical while mapping that case to 0.
        denom = denom.clamp(min=torch.finfo(denom.dtype).tiny)
        return (loss * mask).sum() / denom
|