Spaces:
Runtime error
Runtime error
import torch | |
from isegm.model.losses import SigmoidBinaryCrossEntropyLoss | |
class BRSMaskLoss(torch.nn.Module):
    """Squared-error loss pushing ``result`` to 1 under ``pos_mask`` and to 0 under ``neg_mask``.

    Each of the two terms is divided by the total of its own mask (plus ``eps``),
    so the positive and negative contributions are normalised independently.
    """

    def __init__(self, eps=1e-5):
        super().__init__()
        # Added to each mask total so an empty mask never causes a division by zero.
        self._eps = eps

    def _normalized_sq_error(self, residual, mask):
        """Return ``sum(residual ** 2) / (sum(mask) + eps)`` as a 0-d tensor."""
        return residual.pow(2).sum() / (mask.sum() + self._eps)

    def forward(self, result, pos_mask, neg_mask):
        """Compute the loss and the largest absolute residual under each mask.

        Args:
            result: prediction tensor; presumably values in [0, 1] — confirm with caller.
            pos_mask: mask (same shape as, or broadcastable to, ``result``) marking
                where ``result`` should be 1.
            neg_mask: mask marking where ``result`` should be 0.

        Returns:
            ``(loss, f_max_pos, f_max_neg)``: the differentiable scalar loss tensor,
            followed by two Python numbers giving max |residual| under the positive
            and the negative mask respectively.
        """
        pos_residual = (1 - result) * pos_mask
        neg_residual = result * neg_mask

        pos_term = self._normalized_sq_error(pos_residual, pos_mask)
        neg_term = self._normalized_sq_error(neg_residual, neg_mask)
        loss = pos_term + neg_term

        # Diagnostics only; computed outside the autograd graph.
        with torch.no_grad():
            peak_pos = pos_residual.abs().max().item()
            peak_neg = neg_residual.abs().max().item()

        return loss, peak_pos, peak_neg
class OracleMaskLoss(torch.nn.Module):
    """Loss that scores ``result`` directly against a known ground-truth mask.

    The target is installed with ``set_gt_mask``. Comparison uses
    ``SigmoidBinaryCrossEntropyLoss(from_sigmoid=True)``. Every call records the
    first loss value in ``history``. Once that history has plateaued, ``forward``
    returns ``(0, 0, 0)`` instead of the loss.
    """

    # Plateau test: with more than PLATEAU_WINDOW recorded values, the loss is
    # considered flat when history[-PLATEAU_WINDOW] and history[-1] differ by
    # less than PLATEAU_TOL.
    PLATEAU_WINDOW = 5
    PLATEAU_TOL = 1e-5

    def __init__(self):
        super().__init__()
        # Ground-truth mask; must be provided via set_gt_mask() before forward().
        self.gt_mask = None
        # from_sigmoid=True: ``result`` is presumably already a probability map — confirm.
        self.loss = SigmoidBinaryCrossEntropyLoss(from_sigmoid=True)
        # Assigned externally; only its ``object_roi`` attribute is read. None means "no ROI".
        self.predictor = None
        # One recorded loss value per forward() call since the last set_gt_mask().
        self.history = []

    def set_gt_mask(self, gt_mask):
        """Install a new ground-truth mask and clear the loss history."""
        self.gt_mask = gt_mask
        self.history = []

    def forward(self, result, pos_mask, neg_mask):
        """Compare ``result`` with the stored ground-truth mask.

        Args:
            result: prediction tensor indexed as (batch, channel, H, W), as implied by
                the slicing and ``result.size()[2:]`` below.
            pos_mask: unused; accepted so the signature matches ``BRSMaskLoss.forward``.
            neg_mask: unused; same reason as ``pos_mask``.

        Returns:
            ``(0, 0, 0)`` once the recorded loss has plateaued, otherwise
            ``(loss, 1.0, 1.0)``. This mirrors the three-element tuple of ``BRSMaskLoss``.

        Raises:
            RuntimeError: if ``set_gt_mask`` has not been called.
        """
        if self.gt_mask is None:
            raise RuntimeError("OracleMaskLoss.set_gt_mask() must be called before forward()")
        gt_mask = self.gt_mask.to(result.device)

        roi = None if self.predictor is None else self.predictor.object_roi
        if roi is not None:
            r1, r2, c1, c2 = roi[:4]
            # ROI bounds are inclusive, hence the +1 on both slice ends.
            gt_mask = gt_mask[:, :, r1 : r2 + 1, c1 : c2 + 1]
            # Resample the cropped target to the prediction's spatial size
            # (bilinear assumes a floating-point mask — confirm with caller).
            gt_mask = torch.nn.functional.interpolate(
                gt_mask, result.size()[2:], mode="bilinear", align_corners=True
            )

        if result.shape[0] == 2:
            # A two-sample batch presumably holds (original, mirrored) views; stack the
            # target with its flip along dim 3 (assumed width) so the batches line up.
            gt_mask = torch.cat([gt_mask, torch.flip(gt_mask, dims=[3])], dim=0)

        loss = self.loss(result, gt_mask)
        # reshape(-1) records the first value for a per-sample loss vector and
        # also accepts a 0-d scalar loss, which plain ``[0]`` would reject.
        self.history.append(loss.detach().cpu().numpy().reshape(-1)[0])

        window = self.PLATEAU_WINDOW
        plateaued = (
            len(self.history) > window
            and abs(self.history[-window] - self.history[-1]) < self.PLATEAU_TOL
        )
        if plateaued:
            # Zeros in every slot; presumably read by the caller as "converged" — confirm.
            return 0, 0, 0
        return loss, 1.0, 1.0