File size: 2,447 Bytes
5769ee4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from typing import Optional

import torch


def FDE(
    pred: torch.Tensor, truth: torch.Tensor, mask_loss: Optional[torch.Tensor] = None
):
    """Final displacement error: xy Euclidean distance at the last time step.

    Args:
        pred (Tensor): (..., time, xy) predicted trajectories. Only the first
            two channels of the last axis are used.
        truth (Tensor): (..., time, xy) ground-truth trajectories.
        mask_loss (Tensor, optional): (..., time) weights; only the last time
            step is read. Defaults to None, which gives a plain mean.

    Returns:
        Tensor: scalar. Without a mask, the mean final distance. With a mask,
        the mask-weighted sum of final distances divided by the sum of the
        final-step mask, clamped to at least 1.
    """
    final_offset = pred[..., -1, :2] - truth[..., -1, :2]
    final_dist = final_offset.square().sum(-1).sqrt()  # shape (...)
    if mask_loss is None:
        return final_dist.mean()
    # Only the last time step of the mask matters for a final-step metric.
    final_weight = mask_loss.float()[..., -1]
    return (final_dist * final_weight).sum() / final_weight.sum().clamp_min(1)


def ADE(
    pred: torch.Tensor, truth: torch.Tensor, mask_loss: Optional[torch.Tensor] = None
):
    """Average displacement error: xy Euclidean distance averaged over time.

    Args:
        pred (Tensor): (..., time, xy) predicted trajectories. Only the first
            two channels of the last axis are used.
        truth (Tensor): (..., time, xy) ground-truth trajectories.
        mask_loss (Tensor, optional): (..., time) per-step weights. Defaults
            to None, which gives a plain mean over every step.

    Returns:
        Tensor: scalar. Without a mask, the mean per-step distance. With a
        mask, the mask-weighted sum of per-step distances divided by the mask
        sum, clamped to at least 1.
    """
    step_offset = pred[..., :2] - truth[..., :2]
    step_dist = step_offset.square().sum(-1).sqrt()  # shape (..., time)
    if mask_loss is not None:
        weights = mask_loss.float()
        weighted_total = (step_dist * weights).sum()
        return weighted_total / weights.sum().clamp_min(1)
    return step_dist.mean()


def minFDE(
    pred: torch.Tensor, truth: torch.Tensor, mask_loss: Optional[torch.Tensor] = None
):
    """Minimum final displacement error over predicted samples.

    For every trajectory, the xy Euclidean distance between each sample's last
    predicted point and the last ground-truth point is computed, the smallest
    distance over samples is kept, and those minima are averaged.

    Args:
        pred (Tensor): (..., n_samples, time, xy) sampled predictions. Only the
            first two channels of the last axis are used.
        truth (Tensor): (..., time, xy) ground truth. A tensor that already
            carries a sample axis, (..., 1 or n_samples, time, xy), is also
            accepted.
        mask_loss (Tensor, optional): (..., time) validity mask; only its last
            time step is read and any non-zero value counts as valid. A
            per-sample mask (..., 1 or n_samples, time) is also accepted.
            Masked-out samples never win the minimum, and trajectories without
            a valid sample are excluded from the average. Defaults to None
            (everything valid).

    Returns:
        Tensor: scalar mean of the per-trajectory minimum final distance
        (0 when no trajectory is valid).
    """
    # Without a sample axis, truth's batch axis would line up with pred's
    # sample axis under broadcasting; add one so truth is compared with every
    # sample of its own trajectory.
    if truth.dim() == pred.dim() - 1:
        truth = truth.unsqueeze(-3)
    # Squared final xy distance per sample: (..., n_samples). The square root
    # is taken after the minimum (sqrt is monotonic, so the value is the same),
    # which avoids differentiating sqrt at discarded or masked samples.
    sq_final = torch.sum(torch.square(pred[..., -1, :2] - truth[..., -1, :2]), -1)
    if mask_loss is None:
        min_sq, _ = torch.min(sq_final, -1)
        return torch.mean(torch.sqrt(min_sq))

    if mask_loss.dim() == pred.dim() - 2:
        # (..., time) -> (..., 1, time): one mask shared by all samples.
        mask_loss = mask_loss.unsqueeze(-2)
    sq_final, sample_valid = torch.broadcast_tensors(sq_final, mask_loss[..., -1] != 0)
    # Masked samples are pushed to +inf so they can never be the minimum; this
    # selects rather than multiplies, so non-finite padding cannot leak in.
    min_sq, _ = torch.min(sq_final.masked_fill(~sample_valid, float("inf")), -1)
    traj_valid = sample_valid.any(-1)
    # Trajectories with no valid sample have min_sq == inf: feed sqrt a finite
    # placeholder and then drop them from the sum.
    min_dist = torch.sqrt(torch.where(traj_valid, min_sq, torch.ones_like(min_sq)))
    min_dist = torch.where(traj_valid, min_dist, torch.zeros_like(min_dist))
    return torch.sum(min_dist) / torch.sum(traj_valid).clamp_min(1)