import torch
import numpy as np
import pickle as pkl

import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'src'))
# from priors.pose_prior_35 import Prior
# from priors.tiger_pose_prior.tiger_pose_prior import GaussianMixturePrior
from priors.normalizing_flow_prior.normalizing_flow_prior import NormalizingFlowPrior
from priors.shape_prior import ShapePrior
from lifting_to_3d.utils.geometry_utils import rot6d_to_rotmat, batch_rot2aa, geodesic_loss_R
from combined_model.loss_utils.loss_utils import leg_sideway_error, leg_torsion_error, tail_sideway_error, tail_torsion_error, spine_torsion_error, spine_sideway_error
from combined_model.loss_utils.loss_utils_gc import LossGConMesh, calculate_plane_errors_batch

from configs.SMAL_configs import SMAL_MODEL_CONFIG

from priors.helper_3dcgmodel_loss import load_dog_betas_for_3dcgmodel_loss


class LossRef(torch.nn.Module):
    def __init__(self, smal_model_type, data_info, nf_version=None):
        super(LossRef, self).__init__()
        self.criterion_regr = torch.nn.MSELoss()        # takes the mean   
        self.criterion_class = torch.nn.CrossEntropyLoss()

        class_weights_isflat = torch.tensor([12.0, 2.0])   # float weights, as required by CrossEntropyLoss
        self.criterion_class_isflat = torch.nn.CrossEntropyLoss(weight=class_weights_isflat)
        self.criterion_l1 = torch.nn.L1Loss()
        self.geodesic_loss = geodesic_loss_R(reduction='mean')
        self.gc_loss_on_mesh = LossGConMesh()
        self.data_info = data_info   
        self.smal_model_type = smal_model_type
        self.register_buffer('keypoint_weights', torch.tensor(data_info.keypoint_weights)[None, :])
        # if nf_version is not None:
        #     self.normalizing_flow_pose_prior = NormalizingFlowPrior(nf_version=nf_version)

        self.smal_model_data_path = SMAL_MODEL_CONFIG[self.smal_model_type]['smal_model_data_path']
        self.shape_prior = ShapePrior(self.smal_model_data_path) # here we just need mean and cov        

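        # NOTE: hardcoded, machine-specific cluster path; adapt when running elsewhere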
        remeshing_path = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/smal_data_remeshed/uniform_surface_sampling/my_smpl_39dogsnorm_Jr_4_dog_remesh4000_info.pkl'
        with open(remeshing_path, 'rb') as fp: 
            self.remeshing_dict = pkl.load(fp)
        self.remeshing_relevant_faces = torch.tensor(self.remeshing_dict['smal_faces'][self.remeshing_dict['faceid_closest']], dtype=torch.long)
        self.remeshing_relevant_barys = torch.tensor(self.remeshing_dict['barys_closest'], dtype=torch.float32)



        # load 3d data for the unity dogs (an optional shape prior for 11 breeds)
        self.unity_smal_shape_prior_dogs = SMAL_MODEL_CONFIG[self.smal_model_type]['unity_smal_shape_prior_dogs']
        if self.unity_smal_shape_prior_dogs is not None:
            self.dog_betas_unity = load_dog_betas_for_3dcgmodel_loss(self.unity_smal_shape_prior_dogs, self.smal_model_type)
        else:
            self.dog_betas_unity = None

    def forward(self, output_ref, output_ref_comp, target_dict, weight_dict_ref):
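        """Compute the weighted refinement losses for one batch.

        Returns a scalar total loss (the weighted sum of all active terms)
        together with a dictionary of the individual weighted terms as
        python floats, for logging.
        """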
        # output_reproj: ['vertices_smal', 'keyp_3d', 'keyp_2d', 'silh_image']
        # target_dict: ['index', 'center', 'scale', 'pts', 'tpts', 'target_weight']
        batch_size = output_ref['keyp_2d'].shape[0]
        loss_dict_temp = {}

        # loss on reprojected keypoints 
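        # target keypoints come on a 64x64 heatmap grid and are rescaled to the
        # 256x256 image resolution; channel 2 of 'tpts' holds per-keypoint
        # visibility weights, combined with the dataset-level keypoint weights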
        output_kp_resh = (output_ref['keyp_2d']).reshape((-1, 2))    
        target_kp_resh = (target_dict['tpts'][:, :, :2] / 64. * (256. - 1)).reshape((-1, 2))
        weights_resh = target_dict['tpts'][:, :, 2].reshape((-1)) 
        keyp_w_resh = self.keypoint_weights.repeat((batch_size, 1)).reshape((-1))
        loss_dict_temp['keyp_ref'] = ((((output_kp_resh - target_kp_resh)[weights_resh>0]**2).sum(axis=1).sqrt()*weights_resh[weights_resh>0])*keyp_w_resh[weights_resh>0]).sum() / \
            max((weights_resh[weights_resh>0]*keyp_w_resh[weights_resh>0]).sum(), 1e-5)

        # loss on reprojected silhouette
        assert output_ref['silh'].shape == (target_dict['silh'][:, None, :, :]).shape
        silh_loss_type = 'default'
        if silh_loss_type == 'default':
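            # gate the silhouette term on 2d keypoint quality: samples whose mean
            # keypoint distance exceeds thr_silh are excluded, so badly misaligned
            # fits do not receive a misleading silhouette gradient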
            with torch.no_grad():
                thr_silh = 20
                diff = torch.norm(output_kp_resh - target_kp_resh, dim=1)
                diff_x = diff.reshape((batch_size, -1))
                weights_resh_x = weights_resh.reshape((batch_size, -1))
                unweighted_kp_mean_dist = (diff_x * weights_resh_x).sum(dim=1) / ((weights_resh_x).sum(dim=1)+1e-6)
            loss_silh_bs = ((output_ref['silh'] - target_dict['silh'][:, None, :, :]) ** 2).sum(axis=3).sum(axis=2).sum(axis=1) / (output_ref['silh'].shape[2]*output_ref['silh'].shape[3])
            loss_dict_temp['silh_ref'] = loss_silh_bs[unweighted_kp_mean_dist<thr_silh].sum() / batch_size
        else:
            raise ValueError('unknown silh_loss_type: ' + silh_loss_type)

        # regularization: losses on difference between previous prediction and refinement
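        # the pre-refinement predictions are detached, so these terms only pull
        # the refined values towards the previous stage, never the reverse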
        loss_dict_temp['reg_trans'] = self.criterion_l1(output_ref_comp['ref_trans_notnorm'], output_ref_comp['old_trans_notnorm'].detach()) * 3
        loss_dict_temp['reg_flength'] = self.criterion_l1(output_ref_comp['ref_flength_notnorm'], output_ref_comp['old_flength_notnorm'].detach()) * 1
        loss_dict_temp['reg_pose'] = self.geodesic_loss(output_ref_comp['ref_pose_rotmat'], output_ref_comp['old_pose_rotmat'].detach()) * 35 * 6

        # pose priors on refined pose
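        # each prior penalizes rotation components (sideways bending / torsion)
        # that are implausible for the respective body part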
        loss_dict_temp['pose_legs_side'] = leg_sideway_error(output_ref['pose_rotmat'])
        loss_dict_temp['pose_legs_tors'] = leg_torsion_error(output_ref['pose_rotmat'])
        loss_dict_temp['pose_tail_side'] = tail_sideway_error(output_ref['pose_rotmat'])
        loss_dict_temp['pose_tail_tors'] = tail_torsion_error(output_ref['pose_rotmat'])
        loss_dict_temp['pose_spine_side'] = spine_sideway_error(output_ref['pose_rotmat'])
        loss_dict_temp['pose_spine_tors'] = spine_torsion_error(output_ref['pose_rotmat'])

        # loss to predict ground contact per vertex
        if 'gc_vertexwise' in weight_dict_ref.keys():
            device = output_ref['vertexwise_ground_contact'].device
            pred_gc = output_ref['vertexwise_ground_contact']
            loss_dict_temp['gc_vertexwise'] = self.gc_loss_on_mesh(pred_gc, target_dict['gc'].to(device=device, dtype=torch.long), target_dict['has_gc'], loss_type_gcmesh='ce')

        keep_smal_mesh = False 
        if 'gc_plane' in weight_dict_ref.keys():
            if weight_dict_ref['gc_plane'] > 0:
                if keep_smal_mesh:
                    target_gc_class = target_dict['gc'][:, :, 0]
                    gc_errors_plane = calculate_plane_errors_batch(output_ref['vertices_smal'], target_gc_class, target_dict['has_gc'], target_dict['has_gc_is_touching'])
                    loss_dict_temp['gc_plane'] = torch.mean(gc_errors_plane)
                else:   # use a uniformly sampled mesh
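                    # remeshing: every uniformly sampled surface point is defined by
                    # a face of the original SMAL mesh plus barycentric coordinates
                    # within that face, so vertex positions and ground-contact labels
                    # can both be interpolated; labels are re-binarized by rounding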
                    target_gc_class = target_dict['gc'][:, :, 0]
                    device = output_ref['vertices_smal'].device
                    remeshing_relevant_faces = self.remeshing_relevant_faces.to(device)
                    remeshing_relevant_barys = self.remeshing_relevant_barys.to(device)

                    bs = output_ref['vertices_smal'].shape[0]
                    # gather the three vertices of every relevant face (index_select is
                    # an equivalent, flatter formulation of the fancy indexing
                    # output_ref['vertices_smal'][:, self.remeshing_relevant_faces])
                    sel_verts = torch.index_select(output_ref['vertices_smal'], dim=1, index=remeshing_relevant_faces.reshape((-1))).reshape((bs, remeshing_relevant_faces.shape[0], 3, 3))
                    verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, sel_verts)
                    target_gc_class_remeshed = torch.einsum('ij,aij->ai', remeshing_relevant_barys, target_gc_class[:, self.remeshing_relevant_faces].to(device=device, dtype=torch.float32))
                    target_gc_class_remeshed_prep = torch.round(target_gc_class_remeshed).to(torch.long)
                    gc_errors_plane, gc_errors_under_plane = calculate_plane_errors_batch(verts_remeshed, target_gc_class_remeshed_prep, target_dict['has_gc'], target_dict['has_gc_is_touching'])
                    loss_dict_temp['gc_plane'] = torch.mean(gc_errors_plane)
                    loss_dict_temp['gc_blowplane'] = torch.mean(gc_errors_under_plane)   # penalty for vertices below the estimated ground plane

        # error on classification if the ground plane is flat
        if 'gc_isflat' in weight_dict_ref.keys():
            device = output_ref['isflat'].device
            # use the class-weighted criterion that was set up for this task
            loss_dict_temp['gc_isflat'] = self.criterion_class_isflat.to(device)(output_ref['isflat'], target_dict['isflat'].to(device))

        # if we refine the shape WITHIN the refinement network (i.e. shaperef_type is not 'inexistent'):
        # shape regularization
        #   'smal': loss on betas (pca coefficients), betas should be close to 0
        #   'limbs...' loss on selected betas_limbs
        device = output_ref_comp['ref_trans_notnorm'].device
        loss_shape_weighted_list = [torch.zeros((1), device=device).mean()]  
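        # the list starts with a zero tensor so that torch.stack below never
        # receives an empty list, even if no shape regularizer is configured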
        if 'shape_options' in weight_dict_ref.keys():
            for ind_sp, sp in enumerate(weight_dict_ref['shape_options']):
                weight_sp = weight_dict_ref['shape'][ind_sp]
                # self.logscale_part_list = ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l'] 
                if sp == 'smal':
                    loss_shape_tmp = self.shape_prior(output_ref['betas'])
                elif sp == 'limbs':
                    loss_shape_tmp = torch.mean((output_ref['betas_limbs'])**2)  
                elif sp == 'limbs7':
                    limb_coeffs_list = [0.01, 1, 0.1, 1, 1, 0.1, 2]
                    limb_coeffs = torch.tensor(limb_coeffs_list, dtype=torch.float32, device=target_dict['tpts'].device)
                    loss_shape_tmp = torch.mean((output_ref['betas_limbs'] * limb_coeffs[None, :])**2)            
                else:
                    raise NotImplementedError
                loss_shape_weighted_list.append(weight_sp * loss_shape_tmp)
        loss_shape_weighted = torch.stack(loss_shape_weighted_list).sum()

        # 3D loss for dogs for which we have a unity model or toy figure
        loss_dict_temp['models3d'] = torch.zeros((1), device=device).mean().to(output_ref['betas'].device)
        if 'models3d' in weight_dict_ref.keys() and weight_dict_ref['models3d'] > 0:
            assert self.dog_betas_unity is not None
            for ind_dog in range(target_dict['breed_index'].shape[0]):
                # .item() replaces the deprecated np.asscalar
                breed_index = int(target_dict['breed_index'][ind_dog].item())
                if breed_index in self.dog_betas_unity.keys():
                    betas_target = self.dog_betas_unity[breed_index][:output_ref['betas'].shape[1]].to(output_ref['betas'].device)
                    betas_output = output_ref['betas'][ind_dog, :]
                    betas_limbs_output = output_ref['betas_limbs'][ind_dog, :]
                    loss_dict_temp['models3d'] += ((betas_limbs_output**2).sum() + ((betas_output-betas_target)**2).sum()) / (output_ref['betas'].shape[1] + output_ref['betas_limbs'].shape[1])
        else:
            weight_dict_ref['models3d'] = 0.0

        # weight the losses
        loss = torch.zeros((1)).mean().to(device=output_ref['keyp_2d'].device, dtype=output_ref['keyp_2d'].dtype)
        loss_dict = {}
        for loss_name in weight_dict_ref.keys():
            if loss_name not in ['shape', 'shape_options']:
                if weight_dict_ref[loss_name] > 0:
                    loss_weighted = loss_dict_temp[loss_name] * weight_dict_ref[loss_name]
                    loss_dict[loss_name] = loss_weighted.item()
                    loss += loss_weighted
        loss += loss_shape_weighted
        loss_dict['loss'] = loss.item()

        return loss, loss_dict
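

# Minimal usage sketch (illustrative only; the config values, weights, and the
# 'data_info' object below are assumptions, not values prescribed by this file):
#
#   data_info = ...   # dataset meta information providing 'keypoint_weights'
#   loss_module = LossRef(smal_model_type='39dogs_norm', data_info=data_info)
#   weight_dict_ref = {'keyp_ref': 0.2, 'silh_ref': 50.0,
#                      'shape_options': ['smal'], 'shape': [0.1]}
#   loss, loss_dict = loss_module(output_ref, output_ref_comp, target_dict, weight_dict_ref)
#   loss.backward()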