import torch
import torch.nn as nn
import torch.nn.functional as F

from .abstract_loss_func import AbstractLossClass
from metrics.registry import LOSSFUNC


@LOSSFUNC.register_module(module_name="js_loss")  # registration name assumed; ties in the otherwise-unused registry import
class JS_Loss(AbstractLossClass):
    def __init__(self):
        super().__init__()
    def forward(self, inputs, targets):
        """
        Compute the Jensen-Shannon divergence between the softmax
        distributions of two logit tensors:

            JSD(P, Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M),  where M = (P + Q) / 2
        """
        # Convert logits to probability distributions
        inputs_prob = F.softmax(inputs, dim=1)
        targets_prob = F.softmax(targets, dim=1)

        # Mixture distribution M = (P + Q) / 2
        avg_prob = (inputs_prob + targets_prob) / 2

        # nn.KLDivLoss expects log-probabilities as its first argument and
        # computes KL(target || input), so passing log(M) as the input and
        # P (resp. Q) as the target yields KL(P || M) and KL(Q || M).
        # (The original code passed the arguments the other way around,
        # which computes KL(M || P) + KL(M || Q), not the JSD.)
        # Clamp before the log to avoid -inf from zero probabilities.
        kl_div_loss = nn.KLDivLoss(reduction='batchmean')
        log_avg_prob = avg_prob.clamp(min=1e-12).log()
        kl_inputs = kl_div_loss(log_avg_prob, inputs_prob)
        kl_targets = kl_div_loss(log_avg_prob, targets_prob)

        # Jensen-Shannon divergence: average of the two KL terms
        loss = 0.5 * (kl_inputs + kl_targets)
        return loss
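

# Usage sketch (hypothetical shapes and names; the relative import above means
# this module must be loaded as part of its package rather than run standalone):
#
#     loss_fn = JS_Loss()
#     student_logits = torch.randn(8, 2)   # e.g. a batch of binary-class logits
#     teacher_logits = torch.randn(8, 2)
#     loss = loss_fn(student_logits, teacher_logits)
#
# JSD is symmetric in its two arguments and bounded above by log(2) ≈ 0.693
# nats, so with 'batchmean' reduction the returned scalar should satisfy
# 0 <= loss <= log(2).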