File size: 343 Bytes
482ab8a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
import torch.nn as nn
def get_volume_label_loss(opt):
    """Construct the criterion used for volume-label supervision.

    Args:
        opt: Options/config object. It is not read here; presumably it is
            accepted so this factory shares a signature with other loss
            factories in the project (TODO confirm against callers).

    Returns:
        A freshly constructed ``VolumeLabelLoss`` module.
    """
    criterion = VolumeLabelLoss()
    return criterion
class VolumeLabelLoss(nn.Module):
    """Binary cross-entropy between predicted probabilities and labels.

    Wraps ``nn.BCELoss`` with mean reduction and returns the loss in a dict
    under the key ``"loss"``.
    """

    def __init__(self):
        super().__init__()
        # Mean-reduced BCE. nn.BCELoss expects inputs that are already
        # probabilities in [0, 1] (e.g. sigmoid outputs), not raw logits.
        self.BCE_loss = nn.BCELoss(reduction="mean")

    def forward(self, pred, volume, label):
        """Compute the BCE loss of ``pred`` against ``label``.

        Args:
            pred: Predicted probabilities; values must lie in [0, 1].
                Shape is presumably (batch,) or (batch, 1) -- TODO confirm.
            volume: Not used by this loss; presumably kept so the call
                signature matches other losses in the project (confirm).
            label: Binary targets with the same shape as ``pred``. Labels of
                any dtype (int, bool, float64, ...) are cast to ``pred``'s
                dtype before the loss is evaluated.

        Returns:
            dict: ``{"loss": loss}``, where ``loss`` is a scalar tensor
            (mean over all elements).
        """
        # nn.BCELoss raises when the target dtype differs from the input
        # dtype (e.g. int64 class labels vs. float32 predictions), so align
        # the target's dtype with the prediction's. No-op if already equal.
        target = label.to(dtype=pred.dtype)
        loss = self.BCE_loss(pred, target)
        return {"loss": loss}
|