Spaces:
Running
Running
File size: 2,447 Bytes
5769ee4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
from typing import Optional
import torch
def FDE(
    pred: torch.Tensor, truth: torch.Tensor, mask_loss: Optional[torch.Tensor] = None
):
    """
    Final displacement error: Euclidean distance between the last predicted
    point and the last ground-truth point, aggregated into a single scalar.

    Only the first two components of the last axis are compared; any extra
    components are ignored.

    pred (Tensor): (..., time, xy)
    truth (Tensor): (..., time, xy)
    mask_loss (Tensor): (..., time) Defaults to None. Only its last time step
        is used, as a weight on each final distance.

    Returns a 0-dim tensor: the plain mean of the final distances when no mask
    is given, otherwise the mask-weighted sum divided by the mask total.
    """
    last_step_offset = pred[..., -1, :2] - truth[..., -1, :2]
    endpoint_error = last_step_offset.square().sum(-1).sqrt()

    if mask_loss is None:
        return endpoint_error.mean()

    last_step_weight = mask_loss.float()[..., -1]
    weighted_total = (endpoint_error * last_step_weight).sum()
    # The denominator is floored at 1 so an all-zero mask yields 0, not NaN.
    return weighted_total / last_step_weight.sum().clamp_min(1)
def ADE(
    pred: torch.Tensor, truth: torch.Tensor, mask_loss: Optional[torch.Tensor] = None
):
    """
    Average displacement error: Euclidean distance between predicted and
    ground-truth points at every time step, aggregated into a single scalar.

    Only the first two components of the last axis are compared; any extra
    components are ignored.

    pred (Tensor): (..., time, xy)
    truth (Tensor): (..., time, xy)
    mask_loss (Tensor): (..., time) Defaults to None. Used as a per-step
        weight on the distances.

    Returns a 0-dim tensor: the plain mean of all per-step distances when no
    mask is given, otherwise the mask-weighted sum divided by the mask total.
    """
    planar_offset = pred[..., :, :2] - truth[..., :, :2]
    per_step_error = torch.sum(planar_offset.square(), dim=-1).sqrt()

    if mask_loss is None:
        return per_step_error.mean()

    step_weight = mask_loss.float()
    weighted_total = (per_step_error * step_weight).sum()
    # The denominator is floored at 1 so an all-zero mask yields 0, not NaN.
    return weighted_total / step_weight.sum().clamp_min(1)
def minFDE(
    pred: torch.Tensor, truth: torch.Tensor, mask_loss: Optional[torch.Tensor] = None
):
    """
    Minimum final displacement error: for each trajectory, the Euclidean
    distance between the last predicted point of every sample and the last
    ground-truth point is computed, the smallest distance over the sample axis
    is kept, and the kept values are aggregated into a single scalar.

    Only the first two components of the last axis are compared.

    pred (Tensor): (..., n_samples, time, xy)
    truth (Tensor): (..., time, xy). A tensor of the same rank as pred, i.e.
        with an explicit sample axis (..., 1 or n_samples, time, xy), is also
        accepted and used as-is.
    mask_loss (Tensor): (..., time) Defaults to None. A tensor with an explicit
        sample axis (..., 1 or n_samples, time) is also accepted and used
        as-is. Only its last time step is used.

    Returns a 0-dim tensor: the mean of the per-trajectory minima when no mask
    is given, otherwise their sum over trajectories with at least one unmasked
    sample, divided by the number of such trajectories.
    """
    final_truth = truth[..., -1, :2]
    if truth.dim() == pred.dim() - 1:
        # truth has no sample axis: add one so it broadcasts across the
        # samples instead of being aligned with whichever trailing axis of
        # pred happens to line up (an error, or silently wrong pairings).
        final_truth = final_truth.unsqueeze(-2)
    final_distances = torch.sqrt(
        torch.sum(torch.square(pred[..., -1, :2] - final_truth), -1)
    )

    if mask_loss is None:
        min_distances, _ = torch.min(final_distances, -1)
        return torch.mean(min_distances)

    final_mask = mask_loss[..., -1].float()
    if mask_loss.dim() == pred.dim() - 2:
        # Same reasoning as for truth: give the mask a sample axis so that it
        # broadcasts over samples and the any(-1) below reduces over samples.
        final_mask = final_mask.unsqueeze(-1)

    # Shift masked samples up by the largest unmasked distance so that they
    # cannot be selected as the minimum ahead of a valid sample.
    max_final_distance = torch.max(final_distances * final_mask)
    min_distances, _ = torch.min(
        final_distances + max_final_distance * (1 - final_mask), -1
    )
    # A trajectory contributes only if at least one of its samples is unmasked.
    has_valid_sample = final_mask.any(-1)
    # The count is floored at 1 so a fully masked input yields 0, not NaN.
    return torch.sum(min_distances * has_valid_sample) / torch.sum(
        has_valid_sample
    ).clamp_min(1)
|