jpdefrutos commited on
Commit
20d4b07
·
1 Parent(s): 2676c58

Removed NCC, training only on Hausdorff and DICE

Browse files
TrainingScripts/Train_3d_weaklySupervised.py CHANGED
@@ -54,15 +54,15 @@ def dice_loss(y_true, y_pred):
54
  # Dice().loss returns -Dice score
55
  return 1 + vxm.losses.Dice().loss(y_true, y_pred)
56
 
57
- multiLoss = UncertaintyWeighting(num_loss_fns=3,
58
  num_reg_fns=1,
59
- loss_fns=[HausdorffDistance(3, 5).loss, dice_loss, vxm.losses.NCC().loss],
60
  reg_fns=[vxm.losses.Grad('l2').loss],
61
  prior_loss_w=[1., 1., 1.],
62
  prior_reg_w=[0.01],
63
  name='MultiLossLayer')
64
- loss = multiLoss([vxm_model.inputs[1], vxm_model.inputs[1], fix_img,
65
- vxm_model.references.pred_segm, vxm_model.references.pred_segm, vxm_model.references.pred_img,
66
  grad,
67
  vxm_model.references.pos_flow])
68
 
 
54
  # Dice().loss returns -Dice score
55
  return 1 + vxm.losses.Dice().loss(y_true, y_pred)
56
 
57
+ multiLoss = UncertaintyWeighting(num_loss_fns=2,
58
  num_reg_fns=1,
59
+ loss_fns=[HausdorffDistance(3, 5).loss, dice_loss],
60
  reg_fns=[vxm.losses.Grad('l2').loss],
61
  prior_loss_w=[1., 1., 1.],
62
  prior_reg_w=[0.01],
63
  name='MultiLossLayer')
64
+ loss = multiLoss([vxm_model.inputs[1], vxm_model.inputs[1],
65
+ vxm_model.references.pred_segm, vxm_model.references.pred_segm,
66
  grad,
67
  vxm_model.references.pos_flow])
68