Spaces:
Sleeping
Sleeping
from ..custom_types import * | |
def occupancy_bce(predict: T, winding_gt: T, ignore: Optional[T] = None, *args) -> T:
    """Binary cross-entropy between occupancy logits and a ground-truth occupancy signal.

    Args:
        predict: Raw (pre-sigmoid) logits, since they are fed to
            ``binary_cross_entropy_with_logits``. Flattened before use.
        winding_gt: Ground truth with the same number of elements as ``predict``.
            A bool tensor is used as-is. Any other dtype is thresholded so that
            entries > 0 become label 1 and all others label 0.
        ignore: Optional mask with the same number of elements as ``predict``.
            Entries that are True (or non-zero, for non-bool masks) get weight 0.
        *args: Accepted and ignored. Presumably this lets the function share a
            common loss signature with other callers; confirm before removing.

    Returns:
        Scalar loss tensor. With the default ``mean`` reduction, masked entries
        still count in the denominator. The loss is therefore averaged over all
        elements, not only over the unmasked ones.
    """
    if winding_gt.dtype is not torch.bool:
        winding_gt = winding_gt.gt(0)
    labels = winding_gt.flatten().float()
    predict = predict.flatten()
    weight = None
    if ignore is not None:
        # `~` is a logical NOT only on bool tensors. On integer masks it is a
        # bitwise NOT (0 -> -1, 1 -> -2), which would give negative weights.
        # On float masks it raises. Normalise to bool before negating.
        if ignore.dtype is not torch.bool:
            ignore = ignore.bool()
        weight = (~ignore).flatten().float()
    return nnf.binary_cross_entropy_with_logits(predict, labels, weight=weight)
def reg_z_loss(z: T) -> T:
    """Return the mean L2 norm of ``z`` taken along dimension 1.

    Each slice along dim 1 is reduced to its Euclidean length. The resulting
    norms are then averaged into a single scalar tensor.
    """
    per_row_l2 = z.norm(p=2, dim=1)
    return per_row_l2.mean()