File size: 3,903 Bytes
7088d16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from pytorch3d.loss import mesh_edge_loss
from pytorch3d.structures import Meshes

from .common_testing import TestCaseMixin
from .test_sample_points_from_meshes import init_meshes


class TestMeshEdgeLoss(TestCaseMixin, unittest.TestCase):
    """Tests and a benchmark factory for ``pytorch3d.loss.mesh_edge_loss``."""

    def setUp(self) -> None:
        super().setUp()
        # The tests below draw random mesh sizes and topologies from the global
        # RNG; a fixed seed makes any failure reproducible.
        torch.manual_seed(1)

    @staticmethod
    def _get_device() -> torch.device:
        """
        Return ``cuda:0`` when CUDA is available, otherwise the CPU, so the
        correctness tests can also run on CPU-only hosts.
        """
        return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def test_empty_meshes(self):
        """
        Meshes that have vertices but no faces must give a zero loss that
        still requires grad.
        """
        device = self._get_device()
        target_length = 0
        N = 10
        V = 32
        verts_list = []
        faces_list = []
        for _ in range(N):
            vn = torch.randint(3, high=V, size=(1,))[0].item()
            verts = torch.rand((vn, 3), dtype=torch.float32, device=device)
            faces = torch.tensor([], dtype=torch.int64, device=device)
            verts_list.append(verts)
            faces_list.append(faces)
        mesh = Meshes(verts=verts_list, faces=faces_list)
        loss = mesh_edge_loss(mesh, target_length=target_length)

        self.assertClose(loss, torch.tensor([0.0], dtype=torch.float32, device=device))
        self.assertTrue(loss.requires_grad)

    @staticmethod
    def mesh_edge_loss_naive(meshes, target_length: float = 0.0):
        """
        Naive iterative implementation of mesh loss calculation.

        For every valid mesh, sums ``(||v0 - v1||_2 - target_length) ** 2``
        over its edges and divides by that mesh's edge count (when non-zero).
        Invalid meshes keep a loss of 0. The per-mesh losses are then averaged
        over all ``len(meshes)`` meshes, invalid ones included.

        Args:
            meshes: A ``Meshes`` batch.
            target_length: Desired edge length.

        Returns:
            0-dim float32 tensor with the batch loss.
        """
        edges_packed = meshes.edges_packed()
        verts_packed = meshes.verts_packed()
        edge_to_mesh = meshes.edges_packed_to_mesh_idx()
        N = len(meshes)
        device = meshes.device
        valid = meshes.valid
        predlosses = torch.zeros((N,), dtype=torch.float32, device=device)

        for b in range(N):
            if not valid[b]:
                continue
            mesh_edges = edges_packed[edge_to_mesh == b]
            # Endpoint positions of each edge, indexed [edge, endpoint].
            verts_edges = verts_packed[mesh_edges]
            num_edges = mesh_edges.size(0)
            # Deliberately a per-edge Python loop: this is the reference that
            # the vectorized implementation is compared against.
            for v0, v1 in verts_edges:
                predlosses[b] += ((v0 - v1).norm(dim=0, p=2) - target_length) ** 2.0

            if num_edges > 0:
                predlosses[b] = predlosses[b] / num_edges

        return predlosses.mean()

    def test_mesh_edge_loss_output(self):
        """
        Check outputs of tensorized and iterative implementations are the same.
        """
        device = self._get_device()
        target_length = 0.5
        num_meshes = 10
        num_verts = 32
        num_faces = 64

        verts_list = []
        faces_list = []
        valid = torch.randint(2, size=(num_meshes,))
        # Guarantee at least one non-empty mesh: an all-empty batch would not
        # exercise the mixed valid/empty comparison this test is meant to check.
        valid[0] = 1

        for n in range(num_meshes):
            if valid[n]:
                vn = torch.randint(3, high=num_verts, size=(1,))[0].item()
                fn = torch.randint(vn, high=num_faces, size=(1,))[0].item()
                verts = torch.rand((vn, 3), dtype=torch.float32, device=device)
                faces = torch.randint(
                    vn, size=(fn, 3), dtype=torch.int64, device=device
                )
            else:
                verts = torch.tensor([], dtype=torch.float32, device=device)
                faces = torch.tensor([], dtype=torch.int64, device=device)
            verts_list.append(verts)
            faces_list.append(faces)
        meshes = Meshes(verts=verts_list, faces=faces_list)
        loss = mesh_edge_loss(meshes, target_length=target_length)

        predloss = TestMeshEdgeLoss.mesh_edge_loss_naive(meshes, target_length)
        self.assertClose(loss, predloss)

    @staticmethod
    def mesh_edge_loss(num_meshes: int = 10, max_v: int = 100, max_f: int = 300):
        """
        Build a benchmark closure that runs ``mesh_edge_loss`` on a batch
        built by ``init_meshes``.

        Requires CUDA: the meshes are created on ``cuda:0`` and the closure
        calls ``torch.cuda.synchronize()`` so each call waits for the GPU work
        to finish before returning.

        Returns:
            A zero-argument callable that evaluates the loss once.
        """
        meshes = init_meshes(num_meshes, max_v, max_f, device="cuda:0")
        torch.cuda.synchronize()

        def compute_loss():
            # Class-scope names are not visible inside nested functions, so
            # this resolves to the imported pytorch3d ``mesh_edge_loss``, not
            # to this staticmethod of the same name.
            mesh_edge_loss(meshes, target_length=0.0)
            torch.cuda.synchronize()

        return compute_loss