anyantudre's picture
moved from training repo to inference
caa56d6
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .abstract_loss_func import AbstractLossClass
from metrics.registry import LOSSFUNC
@LOSSFUNC.register_module(module_name="jsloss")
class JS_Loss(AbstractLossClass):
    """Jensen-Shannon divergence between two sets of logits.

    Both ``inputs`` and ``targets`` are treated as unnormalized logits and
    converted to distributions with a softmax over dim 1. The loss is

        JS(P, Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M),   M = (P + Q) / 2

    with the KL terms reduced as ``reduction='batchmean'`` (sum over all
    elements divided by the size of dim 0).
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs, targets):
        """Compute the Jensen-Shannon divergence loss.

        Args:
            inputs: logits; softmax is taken over dim 1. Presumably shaped
                (batch, num_classes) -- TODO confirm against callers.
            targets: logits with the same shape as ``inputs``.

        Returns:
            Scalar tensor: the JS divergence summed over every element and
            divided by ``inputs.size(0)``. For 2-D inputs this is the mean
            per-sample JS divergence, which lies in [0, ln 2].
        """
        # Stay in log space: log_softmax is stable where softmax(...).log()
        # underflows to -inf for very negative logits.
        log_p = F.log_softmax(inputs, dim=1)
        log_q = F.log_softmax(targets, dim=1)
        # log M = log((P + Q) / 2), computed without exponentiating.
        log_m = torch.logaddexp(log_p, log_q) - math.log(2.0)
        # F.kl_div(input, target) evaluates KL(target || exp(input)), so the
        # mixture M must be the *input* argument to obtain KL(P || M) and
        # KL(Q || M). Passing M as the target would instead give KL(M || P),
        # which is not the JS divergence and is unbounded.
        kl_p = F.kl_div(log_m, log_p, reduction='batchmean', log_target=True)
        kl_q = F.kl_div(log_m, log_q, reduction='batchmean', log_target=True)
        return 0.5 * (kl_p + kl_q)