|
import torch.nn as nn |
|
import torch |
|
from .abstract_loss_func import AbstractLossClass |
|
from metrics.registry import LOSSFUNC |
|
|
|
|
|
@LOSSFUNC.register_module(module_name="consistency_loss")
class ConsistencyCos(nn.Module):
    """Cross-entropy loss plus a cosine-consistency penalty on paired features.

    The feature batch is split in two halves and row ``i`` of the first half is
    paired with row ``i`` of the second half (presumably two views of the same
    samples -- confirm against the data pipeline). Each pair's cosine
    similarity is pushed towards 1 with an MSE penalty, weighted by
    ``consistency_rate`` and added to the cross-entropy loss.

    Args:
        consistency_rate: weight of the cosine-consistency term. Defaults to
            1.0, the value that was previously hard-coded in ``forward``.
    """

    def __init__(self, consistency_rate=1.0):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()
        self.mse_fn = nn.MSELoss()
        # Set once here; ``forward`` used to overwrite it with 1.0 on every
        # call, which made the weight impossible to configure.
        self.consistency_rate = consistency_rate

    def forward(self, feat, inputs, targets):
        """Compute the combined loss.

        Args:
            feat: 2-D feature tensor ``(N, C)`` (the ``'nc'`` einsum requires
                rank 2). Rows ``[0, N//2)`` are paired with rows
                ``[N//2, 2*(N//2))``; with an odd ``N`` the last row is unused.
            inputs: passed as the input (logits) to ``nn.CrossEntropyLoss``.
            targets: passed as the target to ``nn.CrossEntropyLoss``.

        Returns:
            Scalar tensor: ``consistency_rate * MSE(cos, 1) + CE(inputs, targets)``.
        """
        feat = nn.functional.normalize(feat, dim=1)
        half = feat.size(0) // 2
        feat_0 = feat[:half]
        # Upper bound 2*half keeps both halves the same length for odd N.
        feat_1 = feat[half:2 * half]

        # Row-wise dot product of unit vectors == cosine similarity, shape (half, 1).
        cos = torch.einsum('nc,nc->n', feat_0, feat_1).unsqueeze(-1)

        if half > 0:
            # ones_like matches cos's device and dtype (the old `.cuda()` call
            # broke for CPU features or non-default GPUs); it does not track grad.
            consistency = self.mse_fn(cos, torch.ones_like(cos))
        else:
            # Fewer than two samples: no pairs exist; avoid NaN from an empty mean.
            consistency = cos.new_zeros(())

        loss = self.consistency_rate * consistency + self.loss_fn(inputs, targets)
        return loss
|
|
|
|
|
|
|
class ConsistencyL2(nn.Module):
    """Mean-squared-error consistency between the two halves of a batch.

    Row ``i`` of the first half of ``feat`` is compared with row ``i`` of the
    second half. With an odd batch size the last row is ignored, matching the
    pairing used by ``ConsistencyCos``.
    """

    def __init__(self):
        super().__init__()
        self.mse_fn = nn.MSELoss()

    def forward(self, feat):
        """Return ``MSE(feat[:N//2], feat[N//2:2*(N//2)])`` as a scalar tensor.

        Args:
            feat: tensor whose first dimension is the batch dimension ``N``.
        """
        half = feat.size(0) // 2
        feat_0 = feat[:half]
        # Bounding the second half at 2*half keeps the shapes equal for odd N.
        # Previously the halves differed by one row, so MSELoss either raised,
        # silently broadcast (N == 3), or returned NaN (N == 1).
        feat_1 = feat[half:2 * half]
        loss = self.mse_fn(feat_0, feat_1)
        return loss
|
|
|
class ConsistencyL1(nn.Module):
    """Mean-absolute-error (L1) consistency between the two halves of a batch.

    Row ``i`` of the first half of ``feat`` is compared with row ``i`` of the
    second half. With an odd batch size the last row is ignored, matching the
    pairing used by ``ConsistencyCos``.
    """

    def __init__(self):
        super().__init__()
        self.L1_fn = nn.L1Loss()

    def forward(self, feat):
        """Return ``L1(feat[:N//2], feat[N//2:2*(N//2)])`` as a scalar tensor.

        Args:
            feat: tensor whose first dimension is the batch dimension ``N``.
        """
        half = feat.size(0) // 2
        feat_0 = feat[:half]
        # Bounding the second half at 2*half keeps the shapes equal for odd N.
        # Previously the halves differed by one row, so L1Loss either raised,
        # silently broadcast (N == 3), or returned NaN (N == 1).
        feat_1 = feat[half:2 * half]
        loss = self.L1_fn(feat_0, feat_1)
        return loss