File size: 2,135 Bytes
801501a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from ..custom_types import *


def get_gm_support(gm, x):
    """Return each mixture component's weighted density at every point, rescaled.

    Computes ``softmax(phi) * N(x | mu, Sigma)`` per component, where the code
    treats ``prod(eigen)`` as ``det(Sigma)`` and ``p^T diag(1 / eigen) p`` as
    ``Sigma^{-1}``. This is valid if the rows of ``p`` are orthonormal
    eigenvectors (assumption, TODO confirm).

    For numerical stability, the exponent of every point is shifted by its
    maximum over components. The returned support therefore equals the true
    weighted density times ``exp(-shift)``.

    Args:
        gm: tuple ``(mu, p, phi, eigen)``. The indexing below implies ``p`` is
            5-D and ``mu``/``eigen`` are 4-D, presumably (b, gp, g, d, d),
            (b, gp, g, d) and (b, gp, g, d).
            ``phi`` holds per-component logits, normalised over axis 2.
            The ``squeeze(1)`` and the broadcast against ``x`` appear to
            require gp == 1 (TODO confirm).
        x: query points, indexed as (b, n, d).

    Returns:
        ``(support, shift)``:
        ``support`` has shape (b, n, g).
        ``shift`` is the per-point maximum exponent, with shape (b, n).
    """
    mu, p, phi, eigen = gm
    num_dims = x.shape[-1]

    # Precision matrix: rows of p scaled by 1 / eigen, then left-multiplied by p^T.
    row_scaled = p * (1 / eigen)[:, :, :, :, None]
    precision = torch.matmul(p.transpose(3, 4), row_scaled).squeeze(1)

    # Mixture weight times the Gaussian normalising constant, per component.
    weights = torch.softmax(phi, dim=2)
    normaliser = torch.sqrt((2 * np.pi) ** num_dims * eigen.prod(-1))
    scale = weights / normaliser

    offsets = x[:, :, None, :] - mu
    exponent = -.5 * torch.einsum('bngd,bgdc,bngc->bng', offsets, precision, offsets)

    # Log-sum-exp style shift: the largest exponent of each point becomes 0.
    shift = exponent.max(dim=2).values
    exponent = exponent - shift[:, :, None]
    return scale * torch.exp(exponent), shift


def gm_log_likelihood_loss(gms: TS, x: T, get_supports: bool = False,
                           mask: Optional[T] = None, reduction: str = "mean") -> Union[T, Tuple[T, TS]]:
    """Negative log-likelihood of the points ``x`` under the Gaussian mixture ``gms``.

    Args:
        gms: mixture parameters, forwarded unchanged to ``get_gm_support``.
        x: points, unpacked as (batch, num_points, dim).
        get_supports: if True, also return the support tensor produced by
            ``get_gm_support``.
        mask: optional boolean mask selecting which per-point log-likelihoods
            contribute.
            A mask with as many elements as the (batch, num_points)
            log-likelihood grid is reshaped to that grid.
            Any other mask is flattened and broadcast against the grid, as
            before.
        reduction: ``'none'`` sums the log-likelihoods over the last axis and
            divides by ``num_points``. Any other value averages over every
            selected log-likelihood.

    Returns:
        The loss, or ``(loss, support)`` when ``get_supports`` is True.
    """
    _, num_points, _ = x.shape
    support, const = get_gm_support(gms, x)
    # ``const`` is the per-point shift removed inside get_gm_support.
    # Adding it back restores the unshifted log mixture density.
    probs = torch.log(support.sum(dim=2)) + const
    if mask is not None:
        if mask.numel() == probs.numel():
            # One mask entry per point: align it with the grid so batch > 1 works.
            mask = mask.reshape(probs.shape)
        else:
            mask = mask.flatten()
        # masked_select always returns a 1-D tensor of the selected entries.
        probs = probs.masked_select(mask=mask)
    if reduction == 'none':
        # NOTE(review): with a mask, probs is 1-D here, so this yields a single
        # scalar normalised by num_points rather than a per-batch loss.
        likelihood = probs.sum(-1)
        loss = - likelihood / num_points
    else:
        likelihood = probs.sum()
        # numel() is batch * num_points without a mask and the number of
        # selected entries with one. Indexing probs.shape[1] crashed on the
        # 1-D masked result.
        loss = - likelihood / probs.numel()
    if get_supports:
        return loss, support
    return loss


def split_mesh_by_gmm(mesh: T_Mesh, gmm) -> Dict[int, T]:
    """Partition a mesh's faces by the mixture component each face centroid favours.

    Each face is assigned to the component with the largest support at the
    mean of its vertices. Components that receive no faces map to ``None``.

    Args:
        mesh: ``(vs, faces)`` pair, where ``faces`` indexes rows of ``vs``
            (presumably an (F, k) index tensor — TODO confirm).
        gmm: mixture parameters accepted by ``gm_log_likelihood_loss``.
            The component count is read from ``gmm[1].shape[2]``.

    Returns:
        Dict from component index to the selected rows of ``faces``, or ``None``.
    """
    vs, faces = mesh
    centroids = vs[faces].mean(1)
    # The loss itself is discarded; only the per-component supports are used.
    supports = gm_log_likelihood_loss(gmm, centroids.unsqueeze(0), get_supports=True)[1]
    owner = supports[0].argmax(1)
    num_gaussians = gmm[1].shape[2]
    result = {}
    for idx in range(num_gaussians):
        chosen = owner == idx
        result[idx] = faces[chosen] if chosen.any() else None
    return result


def flatten_gmm(gmm: TS) -> T:
    """Pack mixture parameters into a single per-component feature tensor.

    The two component axes (gp, g) are merged. Then, per component, ``mu``,
    the flattened ``p`` matrix, ``phi`` and ``eigen`` are concatenated along
    the last axis.

    Args:
        gmm: ``(mu, p, phi, eigen)``. ``mu`` is unpacked as (b, gp, g, d), and
            every item is assumed to share the leading (b, gp, g) axes.

    Returns:
        Tensor of shape (b, gp * g, F). F is the sum of mu's trailing size,
        p's flattened trailing size, 1 (for phi) and eigen's trailing size.
        For the usual shapes this is d + d*d + 1 + d (TODO confirm).
    """
    b, gp, g, _ = gmm[0].shape
    # reshape (not view) so non-contiguous inputs, e.g. a transposed ``p``,
    # are accepted. The result is copied by torch.cat either way.
    mu, p, phi, eigen = [item.reshape(b, gp * g, *item.shape[3:]) for item in gmm]
    p = p.reshape(*p.shape[:2], -1)
    return torch.cat((mu, p, phi.unsqueeze(-1), eigen), dim=2)