File size: 5,440 Bytes
744eb4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
#-------------------------------------------------------------------------------
# Name:        utils.py
# Purpose:     utility functions (tree edit distance, Chamfer distances, bone sampling) used for RigNet's losses and metrics
# RigNet Copyright 2020 University of Massachusetts
# RigNet is made available under General Public License Version 3 (GPLv3), or under a Commercial License.
# Please see the LICENSE README.txt file in the main directory for more information and instruction on using and licensing RigNet.
#-------------------------------------------------------------------------------


from apted import APTED, Config
import numpy as np

class CustomConfig(Config):
    """APTED configuration for comparing skeleton hierarchies.

    Relabelling one node as another ("rename") always costs nothing, so node
    attributes such as joint positions do not contribute to the edit distance.
    Insertion/deletion costs are not overridden here and come from the base
    ``apted.Config``.
    """

    # NOTE(review): presumably the numeric type APTED uses for costs — confirm
    # against apted.Config.
    valuecls = float

    def rename(self, node1, node2):
        """Return the cost of relabelling ``node1`` as ``node2``; always 0."""
        return 0

    def children(self, node):
        """Return ``node.children``, or an empty list when ``node`` is falsy."""
        return node.children if node else []


def getJointNum(skel):
    """Count the joints of ``skel`` by walking its hierarchy level by level.

    The root (``skel.root``) counts as one joint. Every node reached through
    ``children`` adds one more.

    Returns:
        The total number of nodes in the hierarchy.
    """
    frontier = [skel.root]
    total = 1
    while frontier:
        kids = [child for node in frontier for child in node.children]
        total += len(kids)
        frontier = kids
    return total


def dist_pts2bone(pts, pos_1, pos_2):
    """Measure how far each point lies from the bone segment ``pos_1``-``pos_2``.

    Args:
        pts: (N, D) array of query points.
        pos_1: first bone endpoint, a length-D array-like.
        pos_2: second bone endpoint, a length-D array-like.

    Returns:
        A tuple ``(dist_to_lineseg, dist_proj)`` of length-N arrays.

        ``dist_to_lineseg`` is the Euclidean distance from each point to the
        closest point on the segment.

        ``dist_proj`` is measured along the bone's line. It is the distance
        from each point's unclamped projection to ``pos_1`` when the projection
        parameter t < 0.5, and to ``pos_2`` otherwise.

        For a degenerate bone (squared length < 1e-10), both outputs are the
        distance from each point to ``pos_1``.
    """
    pts = np.asarray(pts)
    pos_1 = np.asarray(pos_1)
    pos_2 = np.asarray(pos_2)
    bone = pos_2 - pos_1
    l2 = np.sum(bone ** 2)
    if l2 < 1e-10:
        # Degenerate bone: both measures reduce to the distance to pos_1.
        # Return a copy so the two outputs do not alias each other.
        dist_to_lineseg = np.linalg.norm(pts - pos_1, axis=1)
        return dist_to_lineseg, dist_to_lineseg.copy()

    # Unclamped projection parameter: 0 at pos_1, 1 at pos_2.
    t_ = np.sum((pts - pos_1) * bone, axis=1) / l2
    t = np.clip(t_, 0, 1)
    t_pos = pos_1 + t[:, np.newaxis] * bone
    dist_to_lineseg = np.linalg.norm(pts - t_pos, axis=1)

    lineseg_len = np.linalg.norm(bone)
    # Distance along the line to pos_1 (t < 0.5) or pos_2 (t >= 0.5).
    dist_proj = np.where(t_ < 0.5, np.abs(t_), np.abs(t_ - 1.0)) * lineseg_len
    return dist_to_lineseg, dist_proj


def chamfer_dist(pt1, pt2):
    """Symmetric Chamfer distance between two point sets.

    For every point in one set, take the Euclidean distance to its nearest
    neighbour in the other set. Average these per set, then return the mean
    of the two averages.

    Args:
        pt1: (N1, D) array of points.
        pt2: (N2, D) array of points.
    """
    offsets = pt1[:, np.newaxis, :] - pt2[np.newaxis, :, :]
    pairwise = np.sqrt(np.sum(offsets ** 2, axis=2))  # shape (N1, N2)
    pt1_to_pt2 = np.mean(np.min(pairwise, axis=1))
    pt2_to_pt1 = np.mean(np.min(pairwise, axis=0))
    return (pt1_to_pt2 + pt2_to_pt1) / 2


def oneway_chamfer(pt_src, pt_dst):
    """Mean distance from each source point to its nearest destination point.

    Args:
        pt_src: (Ns, D) array of source points.
        pt_dst: (Nd, D) array of destination points.

    Returns:
        The average, over ``pt_src``, of the Euclidean distance to the closest
        point in ``pt_dst``. The measure is not symmetric.
    """
    gaps = pt_src[:, np.newaxis, :] - pt_dst[np.newaxis, :, :]
    nearest = np.sqrt(np.sum(gaps ** 2, axis=2)).min(axis=1)
    return np.mean(nearest)


def getJointArr(skel):
    """Collect the ``pos`` of every joint of ``skel`` into one array.

    Joints are visited level by level starting from ``skel.root``. Within a
    level they are visited in ``children`` order.

    Returns:
        An array with one row per joint, in that traversal order.
    """
    rows = []
    frontier = [skel.root]
    while frontier:
        rows.extend(np.array(node.pos)[np.newaxis, :] for node in frontier)
        frontier = [child for node in frontier for child in node.children]
    return np.concatenate(rows, axis=0)


def edit_dist(tree1, tree2):
    """Tree edit distance between the hierarchies of two skeletons.

    Runs APTED on ``tree1.root`` and ``tree2.root`` with ``CustomConfig``, in
    which relabelling a node costs nothing. The value returned by
    ``compute_edit_distance`` is passed through unnormalised, i.e. it is not
    divided by the joint count.
    """
    matcher = APTED(tree1.root, tree2.root, CustomConfig())
    return matcher.compute_edit_distance()


def tree_dist(tree1, tree2, ted_weight):
    """Blend tree edit distance with joint-position Chamfer distance.

    Args:
        tree1: first skeleton (needs ``.root``).
        tree2: second skeleton (needs ``.root``).
        ted_weight: weight of the edit-distance term. The Chamfer term gets
            ``1 - ted_weight``.

    Returns:
        ``(1 - ted_weight) * cd + ted_weight * ted``. Here ``cd`` is the
        symmetric Chamfer distance between the two joint arrays, and ``ted``
        is ``edit_dist(tree1, tree2)``.
    """
    ted = edit_dist(tree1, tree2)
    cd = chamfer_dist(getJointArr(tree1), getJointArr(tree2))
    return (1 - ted_weight) * cd + ted_weight * ted


def sample_bone(p_pos, ch_pos, step=0.005):
    """Sample points evenly along the bone from ``p_pos`` to ``ch_pos``.

    Args:
        p_pos: parent joint position, shape (1, D). A flat length-D array is
            also accepted.
        ch_pos: child joint position, same shape as ``p_pos``.
        step: target spacing between consecutive samples. The actual spacing
            is ``bone_length / round(bone_length / step)``.

    Returns:
        A (K, D) float array with ``K = round(bone_length / step) + 1`` rows.
        The first row is ``p_pos`` and the last is ``ch_pos``. A bone whose
        length rounds to zero steps yields the single row ``p_pos``.
    """
    p_pos = np.atleast_2d(np.asarray(p_pos, dtype=float))
    ch_pos = np.atleast_2d(np.asarray(ch_pos, dtype=float))
    ray = ch_pos - p_pos
    bone_length = np.sqrt(np.sum(ray ** 2))
    # np.round returns a float; NumPy count arguments (arange/repeat) need an int.
    num_step = int(np.round(bone_length / step))
    i_step = np.arange(num_step + 1)
    # max(..., 1) avoids dividing by zero for a (near-)zero-length bone; the
    # only sample (i_step == [0]) is then exactly p_pos.
    unit_step = ray / max(num_step, 1)
    # Broadcasting (1, D) against (K, 1) yields all K samples; no np.repeat needed.
    return p_pos + unit_step * i_step[:, np.newaxis]


def sample_skel(skel):
    """Densely sample points along every bone of ``skel``.

    The hierarchy is walked level by level from ``skel.root``. Each
    parent-to-child pair is sampled with ``sample_bone``, and all samples are
    concatenated row-wise.

    Note: ``np.concatenate`` raises ``ValueError`` when there is nothing to
    join, i.e. when the root has no children.
    """
    chunks = []
    frontier = [skel.root]
    while frontier:
        upcoming = []
        for parent in frontier:
            parent_pos = np.array([parent.pos])
            upcoming.extend(parent.children)
            for child in parent.children:
                chunks.append(sample_bone(parent_pos, np.array([child.pos])))
        frontier = upcoming
    return np.concatenate(chunks, axis=0)


def bone2bone_chamfer_dist(skel_1, skel_2):
    """Symmetric Chamfer distance between the bone samples of two skeletons.

    Both skeletons are densely sampled along their bones with ``sample_skel``.
    The two sample sets are compared with ``chamfer_dist``: the mean of the two
    directional nearest-neighbour averages.

    This delegates to ``chamfer_dist`` rather than duplicating its pairwise
    computation, which was previously copied here verbatim.
    """
    return chamfer_dist(sample_skel(skel_1), sample_skel(skel_2))


def joint2bone_chamfer_dist(skel1, skel2):
    """Average joint-to-bone Chamfer distance between two skeletons.

    Each skeleton's joints are matched against the bone samples of the other
    skeleton with ``oneway_chamfer``. The two directional means are averaged.
    """
    joints_a = getJointArr(skel1)
    joints_b = getJointArr(skel2)
    bones_a = sample_skel(skel1)
    bones_b = sample_skel(skel2)
    a_to_b = oneway_chamfer(joints_a, bones_b)
    b_to_a = oneway_chamfer(joints_b, bones_a)
    return (a_to_b + b_to_a) / 2