File size: 643 Bytes
3953219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import monai
from .utils import load_config


def get_loss(config: dict):
    """Instantiate the ``monai.losses`` class named in ``config.loss``.

    ``config.loss`` is expected to map a loss class name from ``monai.losses``
    to the keyword arguments for its constructor. Only the first entry of
    ``config.loss`` is used; any further entries are ignored. A ``None`` value
    (e.g. a YAML key written with no arguments) is treated as "no keyword
    arguments".

    Example:

        config.loss
        >>> {'DiceCELoss': {'include_background': False, 'softmax': True, 'to_onehot_y': True}}

        get_loss(config)
        >>> DiceCELoss(
        >>>   (dice): DiceLoss()
        >>>   (cross_entropy): CrossEntropyLoss()
        >>> )

    Args:
        config: Configuration object exposing a ``loss`` attribute. Although
            annotated as ``dict``, it is read via attribute access
            (``config.loss``), so presumably an attribute-access mapping
            (dict subclass or similar) -- confirm with callers.

    Returns:
        An instance of the requested ``monai.losses`` class.

    Raises:
        IndexError: If ``config.loss`` has no entries.
        AttributeError: If the named loss is not found in ``monai.losses``.
    """
    loss_section = config.loss
    try:
        # Only the first configured loss is used; later entries are ignored.
        loss_type = next(iter(loss_section))
    except StopIteration:
        # Keep IndexError (what the original list(...)[0] raised) for callers,
        # but with a message that says what is actually wrong.
        raise IndexError(
            "config.loss is empty; expected one entry mapping a "
            "monai.losses class name to its keyword arguments"
        ) from None

    loss_config = loss_section[loss_type]
    if loss_config is None:
        # A bare YAML key (`DiceLoss:`) loads as None; `**None` would fail.
        loss_config = {}

    try:
        loss_fun = getattr(monai.losses, loss_type)
    except AttributeError as err:
        raise AttributeError(
            f"Unknown loss type {loss_type!r}: not found in monai.losses"
        ) from err

    return loss_fun(**loss_config)