File size: 1,925 Bytes
482ab8a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
from .bundled_loss import BundledLoss
from .consisitency_loss import get_consistency_loss, get_volume_seg_map
from .entropy_loss import get_entropy_loss
from .loss import Loss
from .map_label_loss import get_map_label_loss
from .map_mask_loss import get_map_mask_loss
from .multi_view_consistency_loss import (
get_multi_view_consistency_loss,
get_spixel_tgt_map,
)
from .volume_label_loss import get_volume_label_loss
from .volume_mask_loss import get_volume_mask_loss
def get_bundled_loss(opt):
    """Build the loss used for overall training, including multi-view consistency.

    Combines the single-model loss from ``get_loss`` with a multi-view
    consistency loss and a volume mask loss, and wraps them in ``BundledLoss``
    together with several options read from ``opt``.

    Args:
        opt: Options object passed to each loss factory. Its ``mvc_weight``,
            ``mvc_time_dependent``, ``mvc_steepness``, ``modality``,
            ``consistency_weight`` and ``consistency_source`` attributes are
            forwarded to ``BundledLoss``.

    Returns:
        The constructed ``BundledLoss`` instance.
    """
    base_loss = get_loss(opt)
    mvc_loss = get_multi_view_consistency_loss(opt)
    # NOTE: get_loss(opt) also calls get_volume_mask_loss(opt) internally, so
    # this creates a second, separate volume-mask-loss object.
    mask_loss = get_volume_mask_loss(opt)

    # Multi-view-consistency options, kept in the positional order
    # BundledLoss receives them.
    mvc_settings = (opt.mvc_weight, opt.mvc_time_dependent, opt.mvc_steepness)

    return BundledLoss(
        base_loss,
        mvc_loss,
        mask_loss,
        *mvc_settings,
        opt.modality,
        opt.consistency_weight,
        opt.consistency_source,
    )
def get_loss(opt):
    """Build the single-model loss, excluding the multi-view consistency term.

    Instantiates every component loss from ``opt`` and passes them, followed
    by their weights and the consistency source, to ``Loss``.

    Args:
        opt: Options object passed to each component-loss factory. Its
            ``*_weight`` attributes and ``consistency_source`` are forwarded
            to ``Loss``.

    Returns:
        The constructed ``Loss`` instance.
    """
    # Component losses, in the positional order Loss receives them.
    components = (
        get_map_label_loss(opt),
        get_volume_label_loss(opt),
        get_map_mask_loss(opt),
        get_volume_mask_loss(opt),
        get_consistency_loss(opt),
        get_entropy_loss(opt),
    )
    # Weights follow the components positionally. There are two entropy
    # weights (map and volume) but only a single entropy-loss object above.
    weights = (
        opt.map_label_weight,
        opt.volume_label_weight,
        opt.map_mask_weight,
        opt.volume_mask_weight,
        opt.consistency_weight,
        opt.map_entropy_weight,
        opt.volume_entropy_weight,
    )
    return Loss(*components, *weights, opt.consistency_source)
|