|
from typing import Dict, Callable
|
|
import torch.nn as nn
|
|
import torch
|
|
|
|
from losses import SoftDiceLoss, SSLoss, IoULoss, TverskyLoss, FocalTversky_loss, AsymLoss, ExpLog_loss, FocalLoss, LovaszSoftmax, TopKLoss, WeightedCrossEntropyLoss, SoftDiceLoss_v2, IoULoss_v2, TverskyLoss_v2, FocalTversky_loss_v2, AsymLoss_v2, SSLoss_v2
|
|
|
|
|
|
def get_loss(loss_type: str) -> Callable | None:
|
|
if loss_type == "cross_entropy":
|
|
return nn.CrossEntropyLoss()
|
|
elif loss_type == "SoftDiceLoss":
|
|
return SoftDiceLoss()
|
|
elif loss_type == "SSLoss":
|
|
return SSLoss()
|
|
elif loss_type == "IoULoss":
|
|
return IoULoss()
|
|
elif loss_type == "TverskyLoss":
|
|
return TverskyLoss()
|
|
elif loss_type == "FocalTversky_loss":
|
|
tversky_kwargs = {
|
|
"apply_nonlin": None,
|
|
"batch_dice": False,
|
|
"do_bg": True,
|
|
"smooth": 1.0,
|
|
"square": False
|
|
}
|
|
return FocalTversky_loss(tversky_kwargs=tversky_kwargs)
|
|
elif loss_type == "AsymLoss":
|
|
return AsymLoss()
|
|
elif loss_type == "ExpLog_loss":
|
|
soft_dice_kwargs = {
|
|
"smooth": 1.0
|
|
}
|
|
wce_kwargs = {
|
|
"weight": None
|
|
}
|
|
return ExpLog_loss(soft_dice_kwargs=soft_dice_kwargs, wce_kwargs=wce_kwargs)
|
|
elif loss_type == "FocalLoss":
|
|
return FocalLoss()
|
|
elif loss_type == "LovaszSoftmax":
|
|
return LovaszSoftmax()
|
|
elif loss_type == "TopKLoss":
|
|
return TopKLoss()
|
|
elif loss_type == "WeightedCrossEntropyLoss":
|
|
return WeightedCrossEntropyLoss()
|
|
elif loss_type == "SoftDiceLoss_v2":
|
|
return SoftDiceLoss_v2()
|
|
elif loss_type == "IoULoss_v2":
|
|
return IoULoss_v2()
|
|
elif loss_type == "TverskyLoss_v2":
|
|
return TverskyLoss_v2()
|
|
elif loss_type == "FocalTversky_loss_v2":
|
|
return FocalTversky_loss_v2()
|
|
elif loss_type == "AsymLoss_v2":
|
|
return AsymLoss_v2()
|
|
elif loss_type == "SSLoss_v2":
|
|
return SSLoss_v2()
|
|
else:
|
|
raise ValueError(f"Unsupported loss type: {loss_type}")
|
|
|
|
|
|
def get_composite_criterion(losses_config: Dict[str, float]) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
    """Build a criterion that returns a weighted sum of several losses.

    Args:
        losses_config: Mapping from loss name (as accepted by ``get_loss``)
            to its weight. Entries whose weight equals ``0.0`` are skipped,
            so ``get_loss`` is never called for them.

    Returns:
        A function ``composite_loss(output, target)`` that evaluates every
        active loss on the same ``(output, target)`` pair and returns the
        sum of ``weight * loss``. When no loss is active (empty config or
        all weights zero) it returns a scalar zero tensor on ``output``'s
        device and dtype, so the result is always a ``torch.Tensor``.

    Raises:
        ValueError: Propagated from ``get_loss`` when a name with a non-zero
            weight is not supported.
    """
    # Keep each loss next to its weight so the two cannot drift out of sync.
    weighted_losses = []
    for loss_name, weight in losses_config.items():
        if weight == 0.0:
            continue
        loss_fn = get_loss(loss_name)
        # get_loss is annotated as possibly returning None; keep the guard.
        if loss_fn is not None:
            weighted_losses.append((loss_fn, weight))

    def composite_loss(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        if not weighted_losses:
            # Without this, the accumulator below would be returned as a
            # plain Python float, contradicting the declared return type.
            # NOTE: this zero tensor is not attached to the autograd graph.
            return output.new_zeros(())

        total_loss = 0.0
        for loss_fn, weight in weighted_losses:
            total_loss += weight * loss_fn(output, target)
        return total_loss

    return composite_loss