File size: 2,820 Bytes
af720c2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
from typing import Dict, Callable
import torch.nn as nn
import torch
from losses import SoftDiceLoss, SSLoss, IoULoss, TverskyLoss, FocalTversky_loss, AsymLoss, ExpLog_loss, FocalLoss, LovaszSoftmax, TopKLoss, WeightedCrossEntropyLoss, SoftDiceLoss_v2, IoULoss_v2, TverskyLoss_v2, FocalTversky_loss_v2, AsymLoss_v2, SSLoss_v2
def get_loss(loss_type: str) -> Callable | None:
if loss_type == "cross_entropy":
return nn.CrossEntropyLoss()
elif loss_type == "SoftDiceLoss":
return SoftDiceLoss()
elif loss_type == "SSLoss":
return SSLoss()
elif loss_type == "IoULoss":
return IoULoss()
elif loss_type == "TverskyLoss":
return TverskyLoss()
elif loss_type == "FocalTversky_loss":
tversky_kwargs = {
"apply_nonlin": None,
"batch_dice": False,
"do_bg": True,
"smooth": 1.0,
"square": False
}
return FocalTversky_loss(tversky_kwargs=tversky_kwargs)
elif loss_type == "AsymLoss":
return AsymLoss()
elif loss_type == "ExpLog_loss":
soft_dice_kwargs = {
"smooth": 1.0
}
wce_kwargs = {
"weight": None
}
return ExpLog_loss(soft_dice_kwargs=soft_dice_kwargs, wce_kwargs=wce_kwargs)
elif loss_type == "FocalLoss":
return FocalLoss()
elif loss_type == "LovaszSoftmax":
return LovaszSoftmax()
elif loss_type == "TopKLoss":
return TopKLoss()
elif loss_type == "WeightedCrossEntropyLoss":
return WeightedCrossEntropyLoss()
elif loss_type == "SoftDiceLoss_v2":
return SoftDiceLoss_v2()
elif loss_type == "IoULoss_v2":
return IoULoss_v2()
elif loss_type == "TverskyLoss_v2":
return TverskyLoss_v2()
elif loss_type == "FocalTversky_loss_v2":
return FocalTversky_loss_v2()
elif loss_type == "AsymLoss_v2":
return AsymLoss_v2()
elif loss_type == "SSLoss_v2":
return SSLoss_v2()
else:
raise ValueError(f"Unsupported loss type: {loss_type}")
def get_composite_criterion(losses_config: Dict[str, float]) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
    """Build a criterion that returns the weighted sum of several losses.

    Args:
        losses_config: Maps a loss name accepted by ``get_loss`` to its
            weight. Entries whose weight equals zero are skipped, so their
            loss object is never constructed.

    Returns:
        A function ``composite_loss(output, target)`` that calls every
        selected loss with the same two arguments and adds up the results,
        each multiplied by its weight. If no loss was selected (empty config
        or only zero weights) it returns a scalar zero tensor.

    Raises:
        ValueError: Propagated from ``get_loss`` when a name with a non-zero
            weight is not supported.
    """
    weighted_losses = []
    for loss_name, weight in losses_config.items():
        if weight == 0.0:
            continue
        loss_fn = get_loss(loss_name)
        # get_loss is annotated as possibly returning None; skip such entries.
        if loss_fn is not None:
            weighted_losses.append((loss_fn, weight))

    def composite_loss(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the weighted sum of the selected losses for one batch."""
        if not weighted_losses:
            # Without this guard the sum below would yield the Python float
            # 0.0, contradicting the declared Tensor return type. The zero
            # tensor takes its dtype and device from ``output``; it is not
            # connected to any autograd graph.
            return output.new_zeros(())
        # Out-of-place accumulation starting from 0.0, so no partial result
        # is modified in place.
        return sum(
            (weight * loss_fn(output, target) for loss_fn, weight in weighted_losses),
            0.0,
        )

    return composite_loss