File size: 1,123 Bytes
482ab8a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import torch
import torch.nn as nn
def get_map_label_loss(opt):
    """Build a ``MapLabelLoss`` from an options object.

    Args:
        opt: Any object exposing a ``label_loss_on_whole_map`` attribute.

    Returns:
        MapLabelLoss: Loss module configured with that flag.
    """
    whole_map_flag = opt.label_loss_on_whole_map
    return MapLabelLoss(label_loss_on_whole_map=whole_map_flag)
class MapLabelLoss(nn.Module):
    """Binary cross-entropy on image-level predictions, optionally
    supervising pristine images through their output map.

    With ``label_loss_on_whole_map=False`` the loss is the mean element-wise
    BCE between ``pred`` and ``label``.

    With ``label_loss_on_whole_map=True`` every sample contributes one term,
    and the terms are averaged over the batch:

    * pristine samples (``label == 0``): BCE between the mean of that
      sample's ``out_map`` and its label;
    * all other samples: mean BCE between that sample's ``pred`` and label.

    ``nn.BCELoss`` expects probabilities, so ``pred`` and ``out_map`` must
    hold values in [0, 1].
    """

    def __init__(self, label_loss_on_whole_map=False):
        super().__init__()
        # Element-wise BCE; reductions are applied explicitly in forward().
        self.bce_loss = nn.BCELoss(reduction="none")
        self.label_loss_on_whole_map = label_loss_on_whole_map

    def forward(self, pred, out_map, label):
        """Compute the loss.

        Args:
            pred: Predicted probabilities with the same shape as ``label``.
            out_map: Per-sample output map whose first dimension is the
                batch dimension; only used when ``label_loss_on_whole_map``
                is True.
            label: Float targets. In whole-map mode each sample must hold
                exactly one value, e.g. shape ``(B,)`` or ``(B, 1)``.

        Returns:
            dict: ``{"loss": scalar tensor}``.
        """
        if not self.label_loss_on_whole_map:
            return {"loss": self.bce_loss(pred, label).mean()}

        batch_size = label.shape[0]
        if batch_size == 0:
            # Nothing to average over; return a zero that stays attached to
            # the graph (instead of dividing by zero) so backward() works.
            return {"loss": pred.sum() * 0.0}

        # One target per sample; flattening makes (B,) and (B, 1) labels
        # behave the same (the map mean below is one value per sample).
        flat_label = label.reshape(batch_size)
        pristine = flat_label == 0

        # Per-sample BCE of the image-level prediction, averaged over any
        # trailing dimensions.
        pred_loss = self.bce_loss(pred, label).reshape(batch_size, -1).mean(dim=1)
        # Per-sample BCE of the mean of the whole output map.
        map_mean = out_map.reshape(batch_size, -1).mean(dim=1)
        map_loss = self.bce_loss(map_mean, flat_label)

        # Pick the relevant term per sample on-device. This avoids one
        # host sync per sample from a Python-level `if label[i] == 0`.
        # The unselected branch receives zero gradient.
        per_sample = torch.where(pristine, map_loss, pred_loss)
        return {"loss": per_sample.mean()}
|