ac5113 commited on
Commit
8d96dcc
·
1 Parent(s): b9f4bb1

cleared files

Browse files
inference.py DELETED
@@ -1,192 +0,0 @@
1
- import torch
2
- import os
3
- import glob
4
- import argparse
5
- import numpy as np
6
- import cv2
7
- import PIL.Image as pil_img
8
- from loguru import logger
9
- import shutil
10
-
11
- import trimesh
12
- import pyrender
13
-
14
- from models.deco import DECO
15
- from common import constants
16
-
17
- os.environ['PYOPENGL_PLATFORM'] = 'egl'
18
-
19
- if torch.cuda.is_available():
20
- device = torch.device('cuda')
21
- else:
22
- device = torch.device('cpu')
23
-
24
- def initiate_model(args):
25
- deco_model = DECO('hrnet', True, device)
26
-
27
- logger.info(f'Loading weights from {args.model_path}')
28
- checkpoint = torch.load(args.model_path)
29
- deco_model.load_state_dict(checkpoint['deco'], strict=True)
30
-
31
- deco_model.eval()
32
-
33
- return deco_model
34
-
35
- def render_image(scene, img_res, img=None, viewer=False):
36
- '''
37
- Render the given pyrender scene and return the image. Can also overlay the mesh on an image.
38
- '''
39
- if viewer:
40
- pyrender.Viewer(scene, use_raymond_lighting=True)
41
- return 0
42
- else:
43
- r = pyrender.OffscreenRenderer(viewport_width=img_res,
44
- viewport_height=img_res,
45
- point_size=1.0)
46
- color, _ = r.render(scene, flags=pyrender.RenderFlags.RGBA)
47
- color = color.astype(np.float32) / 255.0
48
-
49
- if img is not None:
50
- valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis]
51
- input_img = img.detach().cpu().numpy()
52
- output_img = (color[:, :, :-1] * valid_mask +
53
- (1 - valid_mask) * input_img)
54
- else:
55
- output_img = color
56
- return output_img
57
-
58
- def create_scene(mesh, img, focal_length=500, camera_center=250, img_res=500):
59
- # Setup the scene
60
- scene = pyrender.Scene(bg_color=[1.0, 1.0, 1.0, 1.0],
61
- ambient_light=(0.3, 0.3, 0.3))
62
- # add mesh for camera
63
- camera_pose = np.eye(4)
64
- camera_rotation = np.eye(3, 3)
65
- camera_translation = np.array([0., 0, 2.5])
66
- camera_pose[:3, :3] = camera_rotation
67
- camera_pose[:3, 3] = camera_rotation @ camera_translation
68
- pyrencamera = pyrender.camera.IntrinsicsCamera(
69
- fx=focal_length, fy=focal_length,
70
- cx=camera_center, cy=camera_center)
71
- scene.add(pyrencamera, pose=camera_pose)
72
- # create and add light
73
- light = pyrender.PointLight(color=[1.0, 1.0, 1.0], intensity=1)
74
- light_pose = np.eye(4)
75
- for lp in [[1, 1, 1], [-1, 1, 1], [1, -1, 1], [-1, -1, 1]]:
76
- light_pose[:3, 3] = mesh.vertices.mean(0) + np.array(lp)
77
- # out_mesh.vertices.mean(0) + np.array(lp)
78
- scene.add(light, pose=light_pose)
79
- # add body mesh
80
- material = pyrender.MetallicRoughnessMaterial(
81
- metallicFactor=0.0,
82
- alphaMode='OPAQUE',
83
- baseColorFactor=(1.0, 1.0, 0.9, 1.0))
84
- mesh_images = []
85
-
86
- # resize input image to fit the mesh image height
87
- img_height = img_res
88
- img_width = int(img_height * img.shape[1] / img.shape[0])
89
- img = cv2.resize(img, (img_width, img_height))
90
- mesh_images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
91
-
92
- for sideview_angle in [0, 90, 180, 270]:
93
- out_mesh = mesh.copy()
94
- rot = trimesh.transformations.rotation_matrix(
95
- np.radians(sideview_angle), [0, 1, 0])
96
- out_mesh.apply_transform(rot)
97
- out_mesh = pyrender.Mesh.from_trimesh(
98
- out_mesh,
99
- material=material)
100
- mesh_pose = np.eye(4)
101
- scene.add(out_mesh, pose=mesh_pose, name='mesh')
102
- output_img = render_image(scene, img_res)
103
- output_img = pil_img.fromarray((output_img * 255).astype(np.uint8))
104
- output_img = np.asarray(output_img)[:, :, :3]
105
- mesh_images.append(output_img)
106
- # delete the previous mesh
107
- prev_mesh = scene.get_nodes(name='mesh').pop()
108
- scene.remove_node(prev_mesh)
109
-
110
- # show upside down view
111
- for topview_angle in [90, 270]:
112
- out_mesh = mesh.copy()
113
- rot = trimesh.transformations.rotation_matrix(
114
- np.radians(topview_angle), [1, 0, 0])
115
- out_mesh.apply_transform(rot)
116
- out_mesh = pyrender.Mesh.from_trimesh(
117
- out_mesh,
118
- material=material)
119
- mesh_pose = np.eye(4)
120
- scene.add(out_mesh, pose=mesh_pose, name='mesh')
121
- output_img = render_image(scene, img_res)
122
- output_img = pil_img.fromarray((output_img * 255).astype(np.uint8))
123
- output_img = np.asarray(output_img)[:, :, :3]
124
- mesh_images.append(output_img)
125
- # delete the previous mesh
126
- prev_mesh = scene.get_nodes(name='mesh').pop()
127
- scene.remove_node(prev_mesh)
128
-
129
- # stack images
130
- IMG = np.hstack(mesh_images)
131
- IMG = pil_img.fromarray(IMG)
132
- IMG.thumbnail((3000, 3000))
133
- return IMG
134
-
135
- def main(args):
136
- if os.path.isdir(args.img_src):
137
- images = glob.iglob(args.img_src + '/*', recursive=True)
138
- else:
139
- images = [args.img_src]
140
-
141
- deco_model = initiate_model(args)
142
-
143
- smpl_path = os.path.join(constants.SMPL_MODEL_DIR, 'smpl_neutral_tpose.ply')
144
-
145
- for img_name in images:
146
- img = cv2.imread(img_name)
147
- img = cv2.resize(img, (256, 256), cv2.INTER_CUBIC)
148
- img = img.transpose(2,0,1)/255.0
149
- img = img[np.newaxis,:,:,:]
150
- img = torch.tensor(img, dtype = torch.float32).to(device)
151
-
152
- cont, _, _ = deco_model(img)
153
- cont = cont.detach().cpu().numpy().squeeze()
154
- cont_smpl = []
155
- for indx, i in enumerate(cont):
156
- if i >= 0.5:
157
- cont_smpl.append(indx)
158
-
159
- img = img.detach().cpu().numpy()
160
- img = np.transpose(img[0], (1, 2, 0))
161
- img = img * 255
162
- img = img.astype(np.uint8)
163
-
164
- contact_smpl = np.zeros((1, 1, 6890))
165
- contact_smpl[0][0][cont_smpl] = 1
166
-
167
- body_model_smpl = trimesh.load(smpl_path, process=False)
168
- for vert in range(body_model_smpl.visual.vertex_colors.shape[0]):
169
- body_model_smpl.visual.vertex_colors[vert] = args.mesh_colour
170
- body_model_smpl.visual.vertex_colors[cont_smpl] = args.annot_colour
171
-
172
- rend = create_scene(body_model_smpl, img)
173
- os.makedirs(os.path.join(args.out_dir, 'Renders'), exist_ok=True)
174
- rend.save(os.path.join(args.out_dir, 'Renders', os.path.basename(img_name).split('.')[0] + '.png'))
175
-
176
- out_dir = os.path.join(args.out_dir, 'Preds', os.path.basename(img_name).split('.')[0])
177
- os.makedirs(out_dir, exist_ok=True)
178
-
179
- logger.info(f'Saving mesh to {out_dir}')
180
- shutil.copyfile(img_name, os.path.join(out_dir, os.path.basename(img_name)))
181
- body_model_smpl.export(os.path.join(out_dir, 'pred.obj'))
182
-
183
- if __name__=='__main__':
184
- parser = argparse.ArgumentParser()
185
- parser.add_argument('--img_src', help='Source of image(s). Can be file or directory', default='./demo_out', type=str)
186
- parser.add_argument('--out_dir', help='Where to store images', default='./demo_out', type=str)
187
- parser.add_argument('--model_path', help='Path to best model weights', default='./checkpoints/Release_Checkpoint/deco_best.pth', type=str)
188
- parser.add_argument('--mesh_colour', help='Colour of the mesh', nargs='+', type=int, default=[130, 130, 130, 255])
189
- parser.add_argument('--annot_colour', help='Colour of the mesh', nargs='+', type=int, default=[0, 255, 0, 255])
190
- args = parser.parse_args()
191
-
192
- main(args)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tester.py DELETED
@@ -1,61 +0,0 @@
1
- import torch
2
- from torch.utils.data import DataLoader
3
- from loguru import logger
4
-
5
- from train.trainer_step import TrainStepper
6
- from train.base_trainer import evaluator
7
- from data.base_dataset import BaseDataset
8
- from models.deco import DECO
9
- from utils.config import parse_args, run_grid_search_experiments
10
-
11
- def test(hparams):
12
- deco_model = DECO(hparams.TRAINING.ENCODER, hparams.TRAINING.CONTEXT, device)
13
- pytorch_total_params = sum(p.numel() for p in deco_model.parameters() if p.requires_grad)
14
- print('Total number of trainable parameters: ', pytorch_total_params)
15
-
16
- solver = TrainStepper(deco_model, hparams.TRAINING.CONTEXT, hparams.OPTIMIZER.LR, hparams.TRAINING.LOSS_WEIGHTS, hparams.TRAINING.PAL_LOSS_WEIGHTS, device)
17
-
18
- logger.info(f'Loading weights from {hparams.TRAINING.BEST_MODEL_PATH}')
19
- _, _ = solver.load(hparams.TRAINING.BEST_MODEL_PATH)
20
-
21
- # Run testing
22
- for test_loader in val_loaders:
23
- dataset_name = test_loader.dataset.dataset
24
- test_dict, total_time = evaluator(test_loader, solver, hparams, 0, dataset_name, return_dict=True)
25
-
26
- print('Test Contact Precision: ', test_dict['cont_precision'])
27
- print('Test Contact Recall: ', test_dict['cont_recall'])
28
- print('Test Contact F1 Score: ', test_dict['cont_f1'])
29
- print('Test Contact FP Geo. Error: ', test_dict['fp_geo_err'])
30
- print('Test Contact FN Geo. Error: ', test_dict['fn_geo_err'])
31
- if hparams.TRAINING.CONTEXT:
32
- print('Test Contact Semantic Segmentation IoU: ', test_dict['sem_iou'])
33
- print('Test Contact Part Segmentation IoU: ', test_dict['part_iou'])
34
- print('\nTime taken per image for evaluation: ', total_time)
35
- print('-'*50)
36
-
37
- if __name__ == '__main__':
38
- args = parse_args()
39
- hparams = run_grid_search_experiments(
40
- args,
41
- script='tester.py',
42
- change_wt_name=False
43
- )
44
-
45
- if torch.cuda.is_available():
46
- device = torch.device('cuda')
47
- else:
48
- device = torch.device('cpu')
49
-
50
- val_datasets = []
51
- for ds in hparams.VALIDATION.DATASETS:
52
- if ds in ['rich', 'prox']:
53
- val_datasets.append(BaseDataset(ds, 'val', model_type='smplx', normalize=hparams.DATASET.NORMALIZE_IMAGES))
54
- elif ds in ['damon']:
55
- val_datasets.append(BaseDataset(ds, 'val', model_type='smpl', normalize=hparams.DATASET.NORMALIZE_IMAGES))
56
- else:
57
- raise ValueError('Dataset not supported')
58
-
59
- val_loaders = [DataLoader(val_dataset, batch_size=hparams.DATASET.BATCH_SIZE, shuffle=False, num_workers=hparams.DATASET.NUM_WORKERS) for val_dataset in val_datasets]
60
-
61
- test(hparams)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
train.py DELETED
@@ -1,94 +0,0 @@
1
- import torch
2
- from torch.utils.data import DataLoader
3
- import os
4
-
5
- from train.trainer_step import TrainStepper
6
- from train.base_trainer import trainer, evaluator
7
- from data.base_dataset import BaseDataset
8
- from data.mixed_dataset import MixedDataset
9
- from models.deco import DECO
10
- from utils.config import parse_args, run_grid_search_experiments
11
-
12
- def train(hparams):
13
- deco_model = DECO(hparams.TRAINING.ENCODER, hparams.TRAINING.CONTEXT, device)
14
-
15
- solver = TrainStepper(deco_model, hparams.TRAINING.CONTEXT, hparams.OPTIMIZER.LR, hparams.TRAINING.LOSS_WEIGHTS, hparams.TRAINING.PAL_LOSS_WEIGHTS, device)
16
-
17
- vb_f1 = 0
18
- start_ep = 0
19
- num = 0
20
- k = True
21
- latest_model_path = hparams.TRAINING.BEST_MODEL_PATH.replace('best', 'latest')
22
- if os.path.exists(latest_model_path):
23
- _, vb_f1 = solver.load(hparams.TRAINING.BEST_MODEL_PATH)
24
- start_ep, _ = solver.load(latest_model_path)
25
-
26
- for epoch in range(start_ep+1, hparams.TRAINING.NUM_EPOCHS + 1):
27
- # Train one epoch
28
- trainer(epoch, train_loader, solver, hparams)
29
- # Run evaluation
30
- vc_f1 = None
31
- for val_loader in val_loaders:
32
- dataset_name = val_loader.dataset.dataset
33
- vc_f1_ds = evaluator(val_loader, solver, hparams, epoch, dataset_name, normalize=hparams.DATASET.NORMALIZE_IMAGES)
34
- if dataset_name == hparams.VALIDATION.MAIN_DATASET:
35
- vc_f1 = vc_f1_ds
36
- if vc_f1 is None:
37
- raise ValueError('Main dataset not found in validation datasets')
38
-
39
- print('Learning rate: ', solver.lr)
40
-
41
- print('---------------------------------------------')
42
- print('---------------------------------------------')
43
-
44
- solver.save(epoch, vc_f1, latest_model_path)
45
-
46
- if epoch % hparams.TRAINING.CHECKPOINT_EPOCHS == 0:
47
- inter_model_path = latest_model_path.replace('latest', 'epoch_'+str(epoch).zfill(3))
48
- solver.save(epoch, vc_f1, inter_model_path)
49
-
50
- if vc_f1 < vb_f1:
51
- num += 1
52
- print('Not Saving model: Best Val F1 = ', vb_f1, ' Current Val F1 = ', vc_f1)
53
- else:
54
- num = 0
55
- vb_f1 = vc_f1
56
- print('Saving model...')
57
- solver.save(epoch, vb_f1, hparams.TRAINING.BEST_MODEL_PATH)
58
-
59
- if num >= hparams.OPTIMIZER.NUM_UPDATE_LR: solver.update_lr()
60
- if num >= hparams.TRAINING.NUM_EARLY_STOP:
61
- print('Early Stop')
62
- k = False
63
-
64
- if k: continue
65
- else: break
66
-
67
-
68
- if __name__ == '__main__':
69
- args = parse_args()
70
- hparams = run_grid_search_experiments(
71
- args,
72
- script='train.py',
73
- )
74
-
75
- if torch.cuda.is_available():
76
- device = torch.device('cuda')
77
- else:
78
- device = torch.device('cpu')
79
-
80
- train_dataset = MixedDataset(hparams.TRAINING.DATASETS, 'train', dataset_mix_pdf=hparams.TRAINING.DATASET_MIX_PDF, normalize=hparams.DATASET.NORMALIZE_IMAGES)
81
-
82
- val_datasets = []
83
- for ds in hparams.VALIDATION.DATASETS:
84
- if ds in ['rich', 'prox']:
85
- val_datasets.append(BaseDataset(ds, 'val', model_type='smplx', normalize=hparams.DATASET.NORMALIZE_IMAGES))
86
- elif ds in ['damon']:
87
- val_datasets.append(BaseDataset(ds, 'val', model_type='smpl', normalize=hparams.DATASET.NORMALIZE_IMAGES))
88
- else:
89
- raise ValueError('Dataset not supported')
90
-
91
- train_loader = DataLoader(train_dataset, hparams.DATASET.BATCH_SIZE, shuffle=True, num_workers=hparams.DATASET.NUM_WORKERS)
92
- val_loaders = [DataLoader(val_dataset, batch_size=hparams.DATASET.BATCH_SIZE, shuffle=False, num_workers=hparams.DATASET.NUM_WORKERS) for val_dataset in val_datasets]
93
-
94
- train(hparams)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
train/__init__.py DELETED
File without changes
train/base_trainer.py DELETED
@@ -1,103 +0,0 @@
1
- from tqdm import tqdm
2
- from utils.metrics import metric, precision_recall_f1score, det_error_metric
3
- import torch
4
- import numpy as np
5
- from vis.visualize import gen_render
6
-
7
-
8
- def trainer(epoch, train_loader, solver, hparams, compute_metrics=False):
9
-
10
- total_epochs = hparams.TRAINING.NUM_EPOCHS
11
- print('Training Epoch {}/{}'.format(epoch, total_epochs))
12
-
13
- length = len(train_loader)
14
- iterator = tqdm(enumerate(train_loader), total=length, leave=False, desc=f'Training Epoch: {epoch}/{total_epochs}')
15
- for step, batch in iterator:
16
- losses, output = solver.optimize(batch)
17
- return losses, output
18
-
19
- @torch.no_grad()
20
- def evaluator(val_loader, solver, hparams, epoch=0, dataset_name='Unknown', normalize=True, return_dict=False):
21
- total_epochs = hparams.TRAINING.NUM_EPOCHS
22
-
23
- batch_size = val_loader.batch_size
24
- dataset_size = len(val_loader.dataset)
25
- print(f'Dataset size: {dataset_size}')
26
-
27
- val_epoch_cont_pre = np.zeros(dataset_size)
28
- val_epoch_cont_rec = np.zeros(dataset_size)
29
- val_epoch_cont_f1 = np.zeros(dataset_size)
30
- val_epoch_fp_geo_err = np.zeros(dataset_size)
31
- val_epoch_fn_geo_err = np.zeros(dataset_size)
32
- if hparams.TRAINING.CONTEXT:
33
- val_epoch_sem_iou = np.zeros(dataset_size)
34
- val_epoch_part_iou = np.zeros(dataset_size)
35
-
36
- val_epoch_cont_loss = np.zeros(dataset_size)
37
-
38
- total_time = 0
39
-
40
- rend_images = []
41
-
42
- eval_dict = {}
43
-
44
- length = len(val_loader)
45
- iterator = tqdm(enumerate(val_loader), total=length, leave=False, desc=f'Evaluating {dataset_name.capitalize()} Epoch: {epoch}/{total_epochs}')
46
- for step, batch in iterator:
47
- curr_batch_size = batch['img'].shape[0]
48
- losses, output, time_taken = solver.evaluate(batch)
49
-
50
- val_epoch_cont_loss[step * batch_size:step * batch_size + curr_batch_size] = losses['cont_loss'].cpu().numpy()
51
-
52
- # compute metrics
53
- contact_labels_3d = output['contact_labels_3d_gt']
54
- has_contact_3d = output['has_contact_3d']
55
- # check if any value in has_contact_3d tensor is 0
56
- assert torch.any(has_contact_3d == 0) == False, 'has_contact_3d tensor has 0 values'
57
-
58
- contact_labels_3d_pred = output['contact_labels_3d_pred']
59
- if hparams.TRAINING.CONTEXT:
60
- sem_mask_gt = output['sem_mask_gt']
61
- sem_seg_pred = output['sem_mask_pred']
62
- part_mask_gt = output['part_mask_gt']
63
- part_seg_pred = output['part_mask_pred']
64
-
65
- cont_pre, cont_rec, cont_f1 = precision_recall_f1score(contact_labels_3d, contact_labels_3d_pred)
66
- fp_geo_err, fn_geo_err = det_error_metric(contact_labels_3d_pred, contact_labels_3d)
67
- if hparams.TRAINING.CONTEXT:
68
- sem_iou = metric(sem_mask_gt, sem_seg_pred)
69
- part_iou = metric(part_mask_gt, part_seg_pred)
70
-
71
- val_epoch_cont_pre[step * batch_size:step * batch_size + curr_batch_size] = cont_pre.cpu().numpy()
72
- val_epoch_cont_rec[step * batch_size:step * batch_size + curr_batch_size] = cont_rec.cpu().numpy()
73
- val_epoch_cont_f1[step * batch_size:step * batch_size + curr_batch_size] = cont_f1.cpu().numpy()
74
- val_epoch_fp_geo_err[step * batch_size:step * batch_size + curr_batch_size] = fp_geo_err.cpu().numpy()
75
- val_epoch_fn_geo_err[step * batch_size:step * batch_size + curr_batch_size] = fn_geo_err.cpu().numpy()
76
- if hparams.TRAINING.CONTEXT:
77
- val_epoch_sem_iou[step * batch_size:step * batch_size + curr_batch_size] = sem_iou.cpu().numpy()
78
- val_epoch_part_iou[step * batch_size:step * batch_size + curr_batch_size] = part_iou.cpu().numpy()
79
-
80
- total_time += time_taken
81
-
82
- # logging every summary_steps steps
83
- if step % hparams.VALIDATION.SUMMARY_STEPS == 0:
84
- if hparams.TRAINING.CONTEXT:
85
- rend = gen_render(output, normalize)
86
- rend_images.append(rend)
87
-
88
- eval_dict['cont_precision'] = np.sum(val_epoch_cont_pre) / dataset_size
89
- eval_dict['cont_recall'] = np.sum(val_epoch_cont_rec) / dataset_size
90
- eval_dict['cont_f1'] = np.sum(val_epoch_cont_f1) / dataset_size
91
- eval_dict['fp_geo_err'] = np.sum(val_epoch_fp_geo_err) / dataset_size
92
- eval_dict['fn_geo_err'] = np.sum(val_epoch_fn_geo_err) / dataset_size
93
- if hparams.TRAINING.CONTEXT:
94
- eval_dict['sem_iou'] = np.sum(val_epoch_sem_iou) / dataset_size
95
- eval_dict['part_iou'] = np.sum(val_epoch_part_iou) / dataset_size
96
- eval_dict['images'] = rend_images
97
-
98
- total_time /= dataset_size
99
-
100
- val_epoch_cont_loss = np.sum(val_epoch_cont_loss) / dataset_size
101
- if return_dict:
102
- return eval_dict, total_time
103
- return eval_dict['cont_f1']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
train/trainer_step.py DELETED
@@ -1,291 +0,0 @@
1
- from utils.loss import sem_loss_function, class_loss_function, pixel_anchoring_function
2
- import torch
3
- import os
4
- import time
5
-
6
-
7
- class TrainStepper():
8
- def __init__(self, deco_model, context, learning_rate, loss_weight, pal_loss_weight, device):
9
- self.device = device
10
-
11
- self.model = deco_model
12
- self.context = context
13
-
14
- if self.context:
15
- self.optimizer_sem = torch.optim.Adam(params=list(self.model.encoder_sem.parameters()) + list(self.model.decoder_sem.parameters()),
16
- lr=learning_rate, weight_decay=0.0001)
17
- self.optimizer_part = torch.optim.Adam(
18
- params=list(self.model.encoder_part.parameters()) + list(self.model.decoder_part.parameters()), lr=learning_rate,
19
- weight_decay=0.0001)
20
- self.optimizer_contact = torch.optim.Adam(
21
- params=list(self.model.encoder_sem.parameters()) + list(self.model.encoder_part.parameters()) + list(
22
- self.model.cross_att.parameters()) + list(self.model.classif.parameters()), lr=learning_rate, weight_decay=0.0001)
23
-
24
- if self.context: self.sem_loss = sem_loss_function().to(device)
25
- self.class_loss = class_loss_function().to(device)
26
- self.pixel_anchoring_loss_smplx = pixel_anchoring_function(model_type='smplx').to(device)
27
- self.pixel_anchoring_loss_smpl = pixel_anchoring_function(model_type='smpl').to(device)
28
- self.lr = learning_rate
29
- self.loss_weight = loss_weight
30
- self.pal_loss_weight = pal_loss_weight
31
-
32
- def optimize(self, batch):
33
- self.model.train()
34
-
35
- img_paths = batch['img_path']
36
- img = batch['img'].to(self.device)
37
-
38
- img_scale_factor = batch['img_scale_factor'].to(self.device)
39
-
40
- pose = batch['pose'].to(self.device)
41
- betas = batch['betas'].to(self.device)
42
- transl = batch['transl'].to(self.device)
43
- has_smpl = batch['has_smpl'].to(self.device)
44
- is_smplx = batch['is_smplx'].to(self.device)
45
-
46
- cam_k = batch['cam_k'].to(self.device)
47
-
48
- gt_contact_labels_3d = batch['contact_label_3d'].to(self.device)
49
- has_contact_3d = batch['has_contact_3d'].to(self.device)
50
-
51
- if self.context:
52
- sem_mask_gt = batch['sem_mask'].to(self.device)
53
- part_mask_gt = batch['part_mask'].to(self.device)
54
-
55
- polygon_contact_2d = batch['polygon_contact_2d'].to(self.device)
56
- has_polygon_contact_2d = batch['has_polygon_contact_2d'].to(self.device)
57
-
58
- # Forward pass
59
- if self.context:
60
- cont, sem_mask_pred, part_mask_pred = self.model(img)
61
- else:
62
- cont = self.model(img)
63
-
64
- if self.context:
65
- loss_sem = self.sem_loss(sem_mask_gt, sem_mask_pred)
66
- loss_part = self.sem_loss(part_mask_gt, part_mask_pred)
67
- valid_contact_3d = has_contact_3d
68
- loss_cont = self.class_loss(gt_contact_labels_3d, cont, valid_contact_3d)
69
- valid_polygon_contact_2d = has_polygon_contact_2d
70
-
71
- if self.pal_loss_weight > 0 and (is_smplx == 0).sum() > 0:
72
- smpl_body_params = {'pose': pose[is_smplx == 0], 'betas': betas[is_smplx == 0],
73
- 'transl': transl[is_smplx == 0],
74
- 'has_smpl': has_smpl[is_smplx == 0]}
75
- loss_pix_anchoring_smpl, contact_2d_pred_rgb_smpl, _ = self.pixel_anchoring_loss_smpl(cont[is_smplx == 0],
76
- smpl_body_params,
77
- cam_k[is_smplx == 0],
78
- img_scale_factor[
79
- is_smplx == 0],
80
- polygon_contact_2d[
81
- is_smplx == 0],
82
- valid_polygon_contact_2d[
83
- is_smplx == 0])
84
- # weigh the smpl loss based on the number of smpl sample
85
- loss_pix_anchoring = loss_pix_anchoring_smpl * (is_smplx == 0).sum() / len(is_smplx)
86
- contact_2d_pred_rgb = contact_2d_pred_rgb_smpl
87
- else:
88
- loss_pix_anchoring = 0
89
- contact_2d_pred_rgb = torch.zeros_like(polygon_contact_2d)
90
-
91
- if self.context: loss = loss_sem + loss_part + self.loss_weight * loss_cont + self.pal_loss_weight * loss_pix_anchoring
92
- else: loss = self.loss_weight * loss_cont + self.pal_loss_weight * loss_pix_anchoring
93
-
94
- if self.context:
95
- self.optimizer_sem.zero_grad()
96
- self.optimizer_part.zero_grad()
97
- self.optimizer_contact.zero_grad()
98
-
99
- loss.backward()
100
-
101
- if self.context:
102
- self.optimizer_sem.step()
103
- self.optimizer_part.step()
104
- self.optimizer_contact.step()
105
-
106
- if self.context:
107
- losses = {'sem_loss': loss_sem,
108
- 'part_loss': loss_part,
109
- 'cont_loss': loss_cont,
110
- 'pal_loss': loss_pix_anchoring,
111
- 'total_loss': loss}
112
- else:
113
- losses = {'cont_loss': loss_cont,
114
- 'pal_loss': loss_pix_anchoring,
115
- 'total_loss': loss}
116
-
117
- if self.context:
118
- output = {
119
- 'img': img,
120
- 'sem_mask_gt': sem_mask_gt,
121
- 'sem_mask_pred': sem_mask_pred,
122
- 'part_mask_gt': part_mask_gt,
123
- 'part_mask_pred': part_mask_pred,
124
- 'has_contact_2d': has_polygon_contact_2d,
125
- 'contact_2d_gt': polygon_contact_2d,
126
- 'contact_2d_pred_rgb': contact_2d_pred_rgb,
127
- 'has_contact_3d': has_contact_3d,
128
- 'contact_labels_3d_gt': gt_contact_labels_3d,
129
- 'contact_labels_3d_pred': cont}
130
- else:
131
- output = {
132
- 'img': img,
133
- 'has_contact_2d': has_polygon_contact_2d,
134
- 'contact_2d_gt': polygon_contact_2d,
135
- 'contact_2d_pred_rgb': contact_2d_pred_rgb,
136
- 'has_contact_3d': has_contact_3d,
137
- 'contact_labels_3d_gt': gt_contact_labels_3d,
138
- 'contact_labels_3d_pred': cont}
139
-
140
- return losses, output
141
-
142
- @torch.no_grad()
143
- def evaluate(self, batch):
144
- self.model.eval()
145
-
146
- img_paths = batch['img_path']
147
- img = batch['img'].to(self.device)
148
-
149
- img_scale_factor = batch['img_scale_factor'].to(self.device)
150
-
151
- pose = batch['pose'].to(self.device)
152
- betas = batch['betas'].to(self.device)
153
- transl = batch['transl'].to(self.device)
154
- has_smpl = batch['has_smpl'].to(self.device)
155
- is_smplx = batch['is_smplx'].to(self.device)
156
-
157
- cam_k = batch['cam_k'].to(self.device)
158
-
159
- gt_contact_labels_3d = batch['contact_label_3d'].to(self.device)
160
- has_contact_3d = batch['has_contact_3d'].to(self.device)
161
-
162
- if self.context:
163
- sem_mask_gt = batch['sem_mask'].to(self.device)
164
- part_mask_gt = batch['part_mask'].to(self.device)
165
-
166
- polygon_contact_2d = batch['polygon_contact_2d'].to(self.device)
167
- has_polygon_contact_2d = batch['has_polygon_contact_2d'].to(self.device)
168
-
169
- # Forward pass
170
- initial_time = time.time()
171
- if self.context: cont, sem_mask_pred, part_mask_pred = self.model(img)
172
- else: cont = self.model(img)
173
- time_taken = time.time() - initial_time
174
-
175
- if self.context:
176
- loss_sem = self.sem_loss(sem_mask_gt, sem_mask_pred)
177
- loss_part = self.sem_loss(part_mask_gt, part_mask_pred)
178
- valid_contact_3d = has_contact_3d
179
- loss_cont = self.class_loss(gt_contact_labels_3d, cont, valid_contact_3d)
180
- valid_polygon_contact_2d = has_polygon_contact_2d
181
-
182
- if self.pal_loss_weight > 0 and (is_smplx == 0).sum() > 0: # PAL loss only on 2D contacts in HOT which only has SMPL
183
- smpl_body_params = {'pose': pose[is_smplx == 0], 'betas': betas[is_smplx == 0], 'transl': transl[is_smplx == 0],
184
- 'has_smpl': has_smpl[is_smplx == 0]}
185
- loss_pix_anchoring_smpl, contact_2d_pred_rgb_smpl, _ = self.pixel_anchoring_loss_smpl(cont[is_smplx == 0],
186
- smpl_body_params,
187
- cam_k[is_smplx == 0],
188
- img_scale_factor[
189
- is_smplx == 0],
190
- polygon_contact_2d[
191
- is_smplx == 0],
192
- valid_polygon_contact_2d[
193
- is_smplx == 0])
194
- # weight the smpl loss based on the number of smpl samples
195
- contact_2d_pred_rgb = contact_2d_pred_rgb_smpl
196
- loss_pix_anchoring = loss_pix_anchoring_smpl * (is_smplx == 0).sum() / len(is_smplx)
197
- else:
198
- loss_pix_anchoring = 0
199
- contact_2d_pred_rgb = torch.zeros_like(polygon_contact_2d)
200
-
201
- if self.context: loss = loss_sem + loss_part + self.loss_weight * loss_cont + self.pal_loss_weight * loss_pix_anchoring
202
- else: loss = self.loss_weight * loss_cont + self.pal_loss_weight * loss_pix_anchoring
203
-
204
- if self.context:
205
- losses = {'sem_loss': loss_sem,
206
- 'part_loss': loss_part,
207
- 'cont_loss': loss_cont,
208
- 'pal_loss': loss_pix_anchoring,
209
- 'total_loss': loss}
210
- else:
211
- losses = {'cont_loss': loss_cont,
212
- 'pal_loss': loss_pix_anchoring,
213
- 'total_loss': loss}
214
-
215
- if self.context:
216
- output = {
217
- 'img': img,
218
- 'sem_mask_gt': sem_mask_gt,
219
- 'sem_mask_pred': sem_mask_pred,
220
- 'part_mask_gt': part_mask_gt,
221
- 'part_mask_pred': part_mask_pred,
222
- 'has_contact_2d': has_polygon_contact_2d,
223
- 'contact_2d_gt': polygon_contact_2d,
224
- 'contact_2d_pred_rgb': contact_2d_pred_rgb,
225
- 'has_contact_3d': has_contact_3d,
226
- 'contact_labels_3d_gt': gt_contact_labels_3d,
227
- 'contact_labels_3d_pred': cont}
228
- else:
229
- output = {
230
- 'img': img,
231
- 'has_contact_2d': has_polygon_contact_2d,
232
- 'contact_2d_gt': polygon_contact_2d,
233
- 'contact_2d_pred_rgb': contact_2d_pred_rgb,
234
- 'has_contact_3d': has_contact_3d,
235
- 'contact_labels_3d_gt': gt_contact_labels_3d,
236
- 'contact_labels_3d_pred': cont}
237
-
238
- return losses, output, time_taken
239
-
240
- def save(self, ep, f1, model_path):
241
- # create model directory if it does not exist
242
- os.makedirs(os.path.dirname(model_path), exist_ok=True)
243
- if self.context:
244
- torch.save({
245
- 'epoch': ep,
246
- 'deco': self.model.state_dict(),
247
- 'f1': f1,
248
- 'sem_optim': self.optimizer_sem.state_dict(),
249
- 'part_optim': self.optimizer_part.state_dict(),
250
- 'contact_optim': self.optimizer_contact.state_dict()
251
- },
252
- model_path)
253
- else:
254
- torch.save({
255
- 'epoch': ep,
256
- 'deco': self.model.state_dict(),
257
- 'f1': f1,
258
- 'sem_optim': self.optimizer_sem.state_dict(),
259
- 'part_optim': self.optimizer_part.state_dict(),
260
- 'contact_optim': self.optimizer_contact.state_dict()
261
- },
262
- model_path)
263
-
264
- def load(self, model_path):
265
- print(f'~~~ Loading existing checkpoint from {model_path} ~~~')
266
- checkpoint = torch.load(model_path)
267
- self.model.load_state_dict(checkpoint['deco'], strict=True)
268
-
269
- if self.context:
270
- self.optimizer_sem.load_state_dict(checkpoint['sem_optim'])
271
- self.optimizer_part.load_state_dict(checkpoint['part_optim'])
272
- self.optimizer_contact.load_state_dict(checkpoint['contact_optim'])
273
- epoch = checkpoint['epoch']
274
- f1 = checkpoint['f1']
275
- return epoch, f1
276
-
277
- def update_lr(self, factor=2):
278
- if factor:
279
- new_lr = self.lr / factor
280
-
281
- if self.context:
282
- self.optimizer_sem = torch.optim.Adam(params=list(self.model.encoder_sem.parameters()) + list(self.model.decoder_sem.parameters()),
283
- lr=new_lr, weight_decay=0.0001)
284
- self.optimizer_part = torch.optim.Adam(
285
- params=list(self.model.encoder_part.parameters()) + list(self.model.decoder_part.parameters()), lr=new_lr, weight_decay=0.0001)
286
- self.optimizer_contact = torch.optim.Adam(
287
- params=list(self.model.encoder_sem.parameters()) + list(self.model.encoder_part.parameters()) + list(
288
- self.model.cross_att.parameters()) + list(self.model.classif.parameters()), lr=new_lr, weight_decay=0.0001)
289
-
290
- print('update learning rate: %f -> %f' % (self.lr, new_lr))
291
- self.lr = new_lr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/__init__.py DELETED
File without changes
utils/cluster.py DELETED
@@ -1,99 +0,0 @@
1
- import os
2
- import sys
3
- import stat
4
- import shutil
5
- import subprocess
6
-
7
- from loguru import logger
8
-
9
- GPUS = {
10
- 'v100-v16': ('\"Tesla V100-PCIE-16GB\"', 'tesla', 16000),
11
- 'v100-p32': ('\"Tesla V100-PCIE-32GB\"', 'tesla', 32000),
12
- 'v100-s32': ('\"Tesla V100-SXM2-32GB\"', 'tesla', 32000),
13
- 'v100-p16': ('\"Tesla P100-PCIE-16GB\"', 'tesla', 16000),
14
- }
15
-
16
- def get_gpus(min_mem=10000, arch=('tesla', 'quadro', 'rtx')):
17
- gpu_names = []
18
- for k, (gpu_name, gpu_arch, gpu_mem) in GPUS.items():
19
- if gpu_mem >= min_mem and gpu_arch in arch:
20
- gpu_names.append(gpu_name)
21
-
22
- assert len(gpu_names) > 0, 'Suitable GPU model could not be found'
23
-
24
- return gpu_names
25
-
26
-
27
- def execute_task_on_cluster(
28
- script,
29
- exp_name,
30
- output_dir,
31
- condor_dir,
32
- cfg_file,
33
- num_exp=1,
34
- exp_opts=None,
35
- bid_amount=10,
36
- num_workers=2,
37
- memory=64000,
38
- gpu_min_mem=10000,
39
- gpu_arch=('tesla', 'quadro', 'rtx'),
40
- num_gpus=1
41
- ):
42
- # copy config to a new experiment directory and source from there.
43
- # this makes sure the correct config is copied even if you change the config file
44
- # after starting the experiment and before the first job is submitted
45
- temp_config_dir = os.path.join(os.path.dirname(condor_dir), 'temp_configs', exp_name)
46
- os.makedirs(temp_config_dir, exist_ok=True)
47
- new_cfg_file = os.path.join(temp_config_dir, 'config.yaml')
48
- shutil.copy(src=cfg_file, dst=new_cfg_file)
49
-
50
- gpus = get_gpus(min_mem=gpu_min_mem, arch=gpu_arch)
51
-
52
- gpus = ' || '.join([f'CUDADeviceName=={x}' for x in gpus])
53
-
54
- condor_log_dir = os.path.join(condor_dir, 'condorlog', exp_name)
55
- os.makedirs(condor_log_dir, exist_ok=True)
56
- submission = f'executable = {condor_log_dir}/{exp_name}_run.sh\n' \
57
- 'arguments = $(Process) $(Cluster)\n' \
58
- f'error = {condor_log_dir}/{exp_name}_$(Cluster).$(Process).err\n' \
59
- f'output = {condor_log_dir}/{exp_name}_$(Cluster).$(Process).out\n' \
60
- f'log = {condor_log_dir}/{exp_name}_$(Cluster).$(Process).log\n' \
61
- f'request_memory = {memory}\n' \
62
- f'request_cpus={int(num_workers)}\n' \
63
- f'request_gpus={num_gpus}\n' \
64
- f'requirements={gpus}\n' \
65
- f'+MaxRunningPrice = 500\n' \
66
- f'queue {num_exp}'
67
- # f'request_cpus={int(num_workers/2)}\n' \
68
- # f'+RunningPriceExceededAction = \"kill\"\n' \
69
- print('<<< Condor Submission >>> ')
70
- print(submission)
71
-
72
- with open(f'{condor_log_dir}/{exp_name}_submit.sub', 'w') as f:
73
- f.write(submission)
74
-
75
- # output_dir = os.path.join(output_dir, exp_name)
76
- logger.info(f'The logs for this experiments can be found under: {condor_log_dir}')
77
- logger.info(f'The outputs for this experiments can be found under: {output_dir}')
78
- ## This is the trick. Notice there is no --cluster here
79
- bash = 'export PYTHONBUFFERED=1\n export PATH=$PATH\n ' \
80
- f'{sys.executable} {script} --cfg {new_cfg_file} --cfg_id $1'
81
-
82
- if exp_opts is not None:
83
- bash += ' --opts '
84
- for opt in exp_opts:
85
- bash += f'{opt} '
86
- bash += 'SYSTEM.CLUSTER_NODE $2.$1'
87
- else:
88
- bash += ' --opts SYSTEM.CLUSTER_NODE $2.$1'
89
-
90
- executable_path = f'{condor_log_dir}/{exp_name}_run.sh'
91
-
92
- with open(executable_path, 'w') as f:
93
- f.write(bash)
94
-
95
- os.chmod(executable_path, stat.S_IRWXU)
96
-
97
- cmd = ['condor_submit_bid', f'{bid_amount}', f'{condor_log_dir}/{exp_name}_submit.sub']
98
- logger.info('Executing ' + ' '.join(cmd))
99
- subprocess.call(cmd)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/colorwheel.py DELETED
@@ -1,22 +0,0 @@
1
- import cv2
2
- import numpy as np
3
-
4
-
5
- def make_color_wheel_image(img_width, img_height):
6
- """
7
- Creates a color wheel based image of given width and height
8
- Args:
9
- img_width (int):
10
- img_height (int):
11
-
12
- Returns:
13
- opencv image (numpy array): color wheel based image
14
- """
15
- hue = np.fromfunction(lambda i, j: (np.arctan2(i-img_height/2, img_width/2-j) + np.pi)*(180/np.pi)/2,
16
- (img_height, img_width), dtype=np.float)
17
- saturation = np.ones((img_height, img_width)) * 255
18
- value = np.ones((img_height, img_width)) * 255
19
- hsl = np.dstack((hue, saturation, value))
20
- color_map = cv2.cvtColor(np.array(hsl, dtype=np.uint8), cv2.COLOR_HSV2BGR)
21
- return color_map
22
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/config.py DELETED
@@ -1,196 +0,0 @@
1
- import itertools
2
- import operator
3
- import os
4
- import shutil
5
- import time
6
- from functools import reduce
7
- from typing import List, Union
8
-
9
- import configargparse
10
- import yaml
11
- from flatten_dict import flatten, unflatten
12
- from loguru import logger
13
- from yacs.config import CfgNode as CN
14
-
15
- from utils.cluster import execute_task_on_cluster
16
- from utils.default_hparams import hparams
17
-
18
-
19
- def parse_args():
20
- def add_common_cmdline_args(parser):
21
- # for cluster runs
22
- parser.add_argument('--cfg', required=True, type=str, help='cfg file path')
23
- parser.add_argument('--opts', default=[], nargs='*', help='additional options to update config')
24
- parser.add_argument('--cfg_id', type=int, default=0, help='cfg id to run when multiple experiments are spawned')
25
- parser.add_argument('--cluster', default=False, action='store_true', help='creates submission files for cluster')
26
- parser.add_argument('--bid', type=int, default=10, help='amount of bid for cluster')
27
- parser.add_argument('--memory', type=int, default=64000, help='memory amount for cluster')
28
- parser.add_argument('--gpu_min_mem', type=int, default=12000, help='minimum amount of GPU memory')
29
- parser.add_argument('--gpu_arch', default=['tesla', 'quadro', 'rtx'],
30
- nargs='*', help='additional options to update config')
31
- parser.add_argument('--num_cpus', type=int, default=8, help='num cpus for cluster')
32
- return parser
33
-
34
- # For Blender main parser
35
- arg_formatter = configargparse.ArgumentDefaultsHelpFormatter
36
- cfg_parser = configargparse.YAMLConfigFileParser
37
- description = 'PyTorch implementation of DECO'
38
-
39
- parser = configargparse.ArgumentParser(formatter_class=arg_formatter,
40
- config_file_parser_class=cfg_parser,
41
- description=description,
42
- prog='deco')
43
-
44
- parser = add_common_cmdline_args(parser)
45
-
46
- args = parser.parse_args()
47
- print(args, end='\n\n')
48
-
49
- return args
50
-
51
- def get_hparams_defaults():
52
- """Get a yacs hparamsNode object with default values for my_project."""
53
- # Return a clone so that the defaults will not be altered
54
- # This is for the "local variable" use pattern
55
- return hparams.clone()
56
-
57
- def update_hparams(hparams_file):
58
- hparams = get_hparams_defaults()
59
- hparams.merge_from_file(hparams_file)
60
- return hparams.clone()
61
-
62
- def update_hparams_from_dict(cfg_dict):
63
- hparams = get_hparams_defaults()
64
- cfg = hparams.load_cfg(str(cfg_dict))
65
- hparams.merge_from_other_cfg(cfg)
66
- return hparams.clone()
67
-
68
- def get_grid_search_configs(config, excluded_keys=[]):
69
- """
70
- :param config: dictionary with the configurations
71
- :return: The different configurations
72
- """
73
-
74
- def bool_to_string(x: Union[List[bool], bool]) -> Union[List[str], str]:
75
- """
76
- boolean to string conversion
77
- :param x: list or bool to be converted
78
- :return: string converted thinghat
79
- """
80
- if isinstance(x, bool):
81
- return [str(x)]
82
- for i, j in enumerate(x):
83
- x[i] = str(j)
84
- return x
85
-
86
- # exclude from grid search
87
-
88
- flattened_config_dict = flatten(config, reducer='path')
89
- hyper_params = []
90
-
91
- for k,v in flattened_config_dict.items():
92
- if isinstance(v,list):
93
- if k in excluded_keys:
94
- flattened_config_dict[k] = ['+'.join(v)]
95
- elif len(v) > 1:
96
- hyper_params += [k]
97
-
98
- if isinstance(v, list) and isinstance(v[0], bool) :
99
- flattened_config_dict[k] = bool_to_string(v)
100
-
101
- if not isinstance(v,list):
102
- if isinstance(v, bool):
103
- flattened_config_dict[k] = bool_to_string(v)
104
- else:
105
- flattened_config_dict[k] = [v]
106
-
107
- keys, values = zip(*flattened_config_dict.items())
108
- experiments = [dict(zip(keys, v)) for v in itertools.product(*values)]
109
-
110
- for exp_id, exp in enumerate(experiments):
111
- for param in excluded_keys:
112
- exp[param] = exp[param].strip().split('+')
113
- for param_name, param_value in exp.items():
114
- # print(param_name,type(param_value))
115
- if isinstance(param_value, list) and (param_value[0] in ['True', 'False']):
116
- exp[param_name] = [True if x == 'True' else False for x in param_value]
117
- if param_value in ['True', 'False']:
118
- if param_value == 'True':
119
- exp[param_name] = True
120
- else:
121
- exp[param_name] = False
122
-
123
-
124
- experiments[exp_id] = unflatten(exp, splitter='path')
125
-
126
- return experiments, hyper_params
127
-
128
- def get_from_dict(dict, keys):
129
- return reduce(operator.getitem, keys, dict)
130
-
131
- def save_dict_to_yaml(obj, filename, mode='w'):
132
- with open(filename, mode) as f:
133
- yaml.dump(obj, f, default_flow_style=False)
134
-
135
- def run_grid_search_experiments(
136
- args,
137
- script='train.py',
138
- change_wt_name=True
139
- ):
140
- cfg = yaml.safe_load(open(args.cfg))
141
- # parse config file to split into a list of configs with tuning hyperparameters separated
142
- # Also return the names of tuned hyperparameters hyperparameters
143
- different_configs, hyperparams = get_grid_search_configs(
144
- cfg,
145
- excluded_keys=['TRAINING/DATASETS', 'TRAINING/DATASET_MIX_PDF', 'VALIDATION/DATASETS'],
146
- )
147
- logger.info(f'Grid search hparams: \n {hyperparams}')
148
-
149
- # The config file may be missing some default values, so we need to add them
150
- different_configs = [update_hparams_from_dict(c) for c in different_configs]
151
- logger.info(f'======> Number of experiment configurations is {len(different_configs)}')
152
-
153
- config_to_run = CN(different_configs[args.cfg_id])
154
-
155
- if args.cluster:
156
- execute_task_on_cluster(
157
- script=script,
158
- exp_name=config_to_run.EXP_NAME,
159
- output_dir=config_to_run.OUTPUT_DIR,
160
- condor_dir=config_to_run.CONDOR_DIR,
161
- cfg_file=args.cfg,
162
- num_exp=len(different_configs),
163
- bid_amount=args.bid,
164
- num_workers=config_to_run.DATASET.NUM_WORKERS,
165
- memory=args.memory,
166
- exp_opts=args.opts,
167
- gpu_min_mem=args.gpu_min_mem,
168
- gpu_arch=args.gpu_arch,
169
- )
170
- exit()
171
-
172
- # ==== create logdir using hyperparam settings
173
- logtime = time.strftime('%d-%m-%Y_%H-%M-%S')
174
- logdir = f'{logtime}_{config_to_run.EXP_NAME}'
175
- wt_file = config_to_run.EXP_NAME + '_'
176
- for hp in hyperparams:
177
- v = get_from_dict(different_configs[args.cfg_id], hp.split('/'))
178
- logdir += f'_{hp.replace("/", ".").replace("_", "").lower()}-{v}'
179
- wt_file += f'{hp.replace("/", ".").replace("_", "").lower()}-{v}_'
180
- logdir = os.path.join(config_to_run.OUTPUT_DIR, logdir)
181
- os.makedirs(logdir, exist_ok=True)
182
- config_to_run.LOGDIR = logdir
183
-
184
- wt_file += 'best.pth'
185
- wt_path = os.path.join(os.path.dirname(config_to_run.TRAINING.BEST_MODEL_PATH), wt_file)
186
- if change_wt_name: config_to_run.TRAINING.BEST_MODEL_PATH = wt_path
187
-
188
- shutil.copy(src=args.cfg, dst=os.path.join(logdir, 'config.yaml'))
189
-
190
- # save config
191
- save_dict_to_yaml(
192
- unflatten(flatten(config_to_run)),
193
- os.path.join(config_to_run.LOGDIR, 'config_to_run.yaml')
194
- )
195
-
196
- return config_to_run
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/default_hparams.py DELETED
@@ -1,45 +0,0 @@
1
- from yacs.config import CfgNode as CN
2
-
3
- # Set default hparams to construct new default config
4
- # Make sure the defaults are same as in parser
5
- hparams = CN()
6
-
7
- # General settings
8
- hparams.EXP_NAME = 'default'
9
- hparams.PROJECT_NAME = 'default'
10
- hparams.OUTPUT_DIR = 'deco_results/'
11
- hparams.CONDOR_DIR = '/is/cluster/work/achatterjee/condor/rich/'
12
- hparams.LOGDIR = ''
13
-
14
- # Dataset hparams
15
- hparams.DATASET = CN()
16
- hparams.DATASET.BATCH_SIZE = 64
17
- hparams.DATASET.NUM_WORKERS = 4
18
- hparams.DATASET.NORMALIZE_IMAGES = True
19
-
20
- # Optimizer hparams
21
- hparams.OPTIMIZER = CN()
22
- hparams.OPTIMIZER.TYPE = 'adam'
23
- hparams.OPTIMIZER.LR = 5e-5
24
- hparams.OPTIMIZER.NUM_UPDATE_LR = 10
25
-
26
- # Training hparams
27
- hparams.TRAINING = CN()
28
- hparams.TRAINING.ENCODER = 'hrnet'
29
- hparams.TRAINING.CONTEXT = True
30
- hparams.TRAINING.NUM_EPOCHS = 50
31
- hparams.TRAINING.SUMMARY_STEPS = 100
32
- hparams.TRAINING.CHECKPOINT_EPOCHS = 5
33
- hparams.TRAINING.NUM_EARLY_STOP = 10
34
- hparams.TRAINING.DATASETS = ['rich']
35
- hparams.TRAINING.DATASET_MIX_PDF = ['1.']
36
- hparams.TRAINING.DATASET_ROOT_PATH = '/is/cluster/work/achatterjee/rich/npzs'
37
- hparams.TRAINING.BEST_MODEL_PATH = '/is/cluster/work/achatterjee/weights/rich/exp/rich_exp.pth'
38
- hparams.TRAINING.LOSS_WEIGHTS = 1.
39
- hparams.TRAINING.PAL_LOSS_WEIGHTS = 1.
40
-
41
- # Training hparams
42
- hparams.VALIDATION = CN()
43
- hparams.VALIDATION.SUMMARY_STEPS = 100
44
- hparams.VALIDATION.DATASETS = ['rich']
45
- hparams.VALIDATION.MAIN_DATASET = 'rich'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/diff_renderer.py DELETED
@@ -1,287 +0,0 @@
1
- # from https://gitlab.tuebingen.mpg.de/mkocabas/projects/-/blob/master/pare/pare/utils/diff_renderer.py
2
-
3
- import torch
4
- import numpy as np
5
- import torch.nn as nn
6
-
7
- from pytorch3d.renderer import (
8
- PerspectiveCameras,
9
- RasterizationSettings,
10
- DirectionalLights,
11
- BlendParams,
12
- HardFlatShader,
13
- MeshRasterizer,
14
- TexturesVertex,
15
- TexturesAtlas
16
- )
17
- from pytorch3d.structures import Meshes
18
-
19
- from .image_utils import get_default_camera
20
- from .smpl_uv import get_tenet_texture
21
-
22
-
23
- class MeshRendererWithDepth(nn.Module):
24
- """
25
- A class for rendering a batch of heterogeneous meshes. The class should
26
- be initialized with a rasterizer and shader class which each have a forward
27
- function.
28
- """
29
-
30
- def __init__(self, rasterizer, shader):
31
- super().__init__()
32
- self.rasterizer = rasterizer
33
- self.shader = shader
34
-
35
- def forward(self, meshes_world, **kwargs) -> torch.Tensor:
36
- """
37
- Render a batch of images from a batch of meshes by rasterizing and then
38
- shading.
39
-
40
- NOTE: If the blur radius for rasterization is > 0.0, some pixels can
41
- have one or more barycentric coordinates lying outside the range [0, 1].
42
- For a pixel with out of bounds barycentric coordinates with respect to a
43
- face f, clipping is required before interpolating the texture uv
44
- coordinates and z buffer so that the colors and depths are limited to
45
- the range for the corresponding face.
46
- """
47
- fragments = self.rasterizer(meshes_world, **kwargs)
48
- images = self.shader(fragments, meshes_world, **kwargs)
49
-
50
- mask = (fragments.zbuf > -1).float()
51
-
52
- zbuf = fragments.zbuf.view(images.shape[0], -1)
53
- # print(images.shape, zbuf.shape)
54
- depth = (zbuf - zbuf.min(-1, keepdims=True).values) / \
55
- (zbuf.max(-1, keepdims=True).values - zbuf.min(-1, keepdims=True).values)
56
- depth = depth.reshape(*images.shape[:3] + (1,))
57
-
58
- images = torch.cat([images[:, :, :, :3], mask, depth], dim=-1)
59
- return images
60
-
61
-
62
- class DifferentiableRenderer(nn.Module):
63
- def __init__(
64
- self,
65
- img_h,
66
- img_w,
67
- focal_length,
68
- device='cuda',
69
- background_color=(0.0, 0.0, 0.0),
70
- texture_mode='smplpix',
71
- vertex_colors=None,
72
- face_textures=None,
73
- smpl_faces=None,
74
- is_train=False,
75
- is_cam_batch=False,
76
- ):
77
- super(DifferentiableRenderer, self).__init__()
78
- self.x = 'a'
79
- self.img_h = img_h
80
- self.img_w = img_w
81
- self.device = device
82
- self.focal_length = focal_length
83
- K, R = get_default_camera(focal_length, img_h, img_w, is_cam_batch=is_cam_batch)
84
- K, R = K.to(device), R.to(device)
85
-
86
- # T = torch.tensor([[0, 0, 2.5 * self.focal_length / max(self.img_h, self.img_w)]]).to(device)
87
- if is_cam_batch:
88
- T = torch.zeros((K.shape[0], 3)).to(device)
89
- else:
90
- T = torch.tensor([[0.0, 0.0, 0.0]]).to(device)
91
- self.background_color = background_color
92
- self.renderer = None
93
- smpl_faces = smpl_faces
94
-
95
- if texture_mode == 'smplpix':
96
- face_colors = get_tenet_texture(mode=texture_mode).to(device).float()
97
- vertex_colors = torch.from_numpy(
98
- np.load(f'data/smpl/{texture_mode}_vertex_colors.npy')[:,:3]
99
- ).unsqueeze(0).to(device).float()
100
- if texture_mode == 'partseg':
101
- vertex_colors = vertex_colors[..., :3].unsqueeze(0).to(device)
102
- face_colors = face_textures.to(device)
103
- if texture_mode == 'deco':
104
- vertex_colors = vertex_colors[..., :3].to(device)
105
- face_colors = face_textures.to(device)
106
-
107
- self.register_buffer('K', K)
108
- self.register_buffer('R', R)
109
- self.register_buffer('T', T)
110
- self.register_buffer('face_colors', face_colors)
111
- self.register_buffer('vertex_colors', vertex_colors)
112
- self.register_buffer('smpl_faces', smpl_faces)
113
-
114
- self.set_requires_grad(is_train)
115
-
116
- def set_requires_grad(self, val=False):
117
- self.K.requires_grad_(val)
118
- self.R.requires_grad_(val)
119
- self.T.requires_grad_(val)
120
- self.face_colors.requires_grad_(val)
121
- self.vertex_colors.requires_grad_(val)
122
- # check if smpl_faces is a FloatTensor as requires_grad_ is not defined for LongTensor
123
- if isinstance(self.smpl_faces, torch.FloatTensor):
124
- self.smpl_faces.requires_grad_(val)
125
-
126
- def forward(self, vertices, faces=None, R=None, T=None):
127
- raise NotImplementedError
128
-
129
-
130
- class Pytorch3D(DifferentiableRenderer):
131
- def __init__(
132
- self,
133
- img_h,
134
- img_w,
135
- focal_length,
136
- device='cuda',
137
- background_color=(0.0, 0.0, 0.0),
138
- texture_mode='smplpix',
139
- vertex_colors=None,
140
- face_textures=None,
141
- smpl_faces=None,
142
- model_type='smpl',
143
- is_train=False,
144
- is_cam_batch=False,
145
- ):
146
- super(Pytorch3D, self).__init__(
147
- img_h,
148
- img_w,
149
- focal_length,
150
- device=device,
151
- background_color=background_color,
152
- texture_mode=texture_mode,
153
- vertex_colors=vertex_colors,
154
- face_textures=face_textures,
155
- smpl_faces=smpl_faces,
156
- is_train=is_train,
157
- is_cam_batch=is_cam_batch,
158
- )
159
-
160
- # this R converts the camera from pyrender NDC to
161
- # OpenGL coordinate frame. It is basicall R(180, X) x R(180, Y)
162
- # I manually defined it here for convenience
163
- self.R = self.R @ torch.tensor(
164
- [[[ -1.0, 0.0, 0.0],
165
- [ 0.0, -1.0, 0.0],
166
- [ 0.0, 0.0, 1.0]]],
167
- dtype=self.R.dtype, device=self.R.device,
168
- )
169
-
170
- if is_cam_batch:
171
- focal_length = self.focal_length
172
- else:
173
- focal_length = self.focal_length[None, :]
174
-
175
- principal_point = ((self.img_w // 2, self.img_h // 2),)
176
- image_size = ((self.img_h, self.img_w),)
177
-
178
- cameras = PerspectiveCameras(
179
- device=self.device,
180
- focal_length=focal_length,
181
- principal_point=principal_point,
182
- R=self.R,
183
- T=self.T,
184
- in_ndc=False,
185
- image_size=image_size,
186
- )
187
-
188
- for param in cameras.parameters():
189
- param.requires_grad_(False)
190
-
191
- raster_settings = RasterizationSettings(
192
- image_size=(self.img_h, self.img_w),
193
- blur_radius=0.0,
194
- max_faces_per_bin=20000,
195
- faces_per_pixel=1,
196
- )
197
-
198
- lights = DirectionalLights(
199
- device=self.device,
200
- ambient_color=((1.0, 1.0, 1.0),),
201
- diffuse_color=((0.0, 0.0, 0.0),),
202
- specular_color=((0.0, 0.0, 0.0),),
203
- direction=((0, 1, 0),),
204
- )
205
-
206
- blend_params = BlendParams(background_color=self.background_color)
207
-
208
- shader = HardFlatShader(device=self.device,
209
- cameras=cameras,
210
- blend_params=blend_params,
211
- lights=lights)
212
-
213
- self.textures = TexturesVertex(verts_features=self.vertex_colors)
214
-
215
- self.renderer = MeshRendererWithDepth(
216
- rasterizer=MeshRasterizer(
217
- cameras=cameras,
218
- raster_settings=raster_settings
219
- ),
220
- shader=shader,
221
- )
222
-
223
- def forward(self, vertices, faces=None, R=None, T=None, face_atlas=None):
224
- batch_size = vertices.shape[0]
225
- if faces is None:
226
- faces = self.smpl_faces.expand(batch_size, -1, -1)
227
-
228
- if R is None:
229
- R = self.R.expand(batch_size, -1, -1)
230
-
231
- if T is None:
232
- T = self.T.expand(batch_size, -1)
233
-
234
- # convert camera translation to pytorch3d coordinate frame
235
- T = torch.bmm(R, T.unsqueeze(-1)).squeeze(-1)
236
-
237
- vertex_textures = TexturesVertex(
238
- verts_features=self.vertex_colors.expand(batch_size, -1, -1)
239
- )
240
-
241
- # face_textures needed because vertex_texture cause interpolation at boundaries
242
- if face_atlas:
243
- face_textures = TexturesAtlas(atlas=face_atlas)
244
- else:
245
- face_textures = TexturesAtlas(atlas=self.face_colors)
246
-
247
- # we may need to rotate the mesh
248
- meshes = Meshes(verts=vertices, faces=faces, textures=face_textures)
249
- images = self.renderer(meshes, R=R, T=T)
250
- images = images.permute(0, 3, 1, 2)
251
- return images
252
-
253
-
254
- class NeuralMeshRenderer(DifferentiableRenderer):
255
- def __init__(self, *args, **kwargs):
256
- import neural_renderer as nr
257
-
258
- super(NeuralMeshRenderer, self).__init__(*args, **kwargs)
259
-
260
- self.neural_renderer = nr.Renderer(
261
- dist_coeffs=None,
262
- orig_size=self.img_size,
263
- image_size=self.img_size,
264
- light_intensity_ambient=1,
265
- light_intensity_directional=0,
266
- anti_aliasing=False,
267
- )
268
-
269
- def forward(self, vertices, faces=None, R=None, T=None):
270
- batch_size = vertices.shape[0]
271
- if faces is None:
272
- faces = self.smpl_faces.expand(batch_size, -1, -1)
273
-
274
- if R is None:
275
- R = self.R.expand(batch_size, -1, -1)
276
-
277
- if T is None:
278
- T = self.T.expand(batch_size, -1)
279
- rgb, depth, mask = self.neural_renderer(
280
- vertices,
281
- faces,
282
- textures=self.face_colors.expand(batch_size, -1, -1, -1, -1, -1),
283
- K=self.K.expand(batch_size, -1, -1),
284
- R=R,
285
- t=T.unsqueeze(1),
286
- )
287
- return torch.cat([rgb, depth.unsqueeze(1), mask.unsqueeze(1)], dim=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/get_cfg.py DELETED
@@ -1,17 +0,0 @@
1
- from yacs.config import CfgNode
2
-
3
- _VALID_TYPES = {tuple, list, str, int, float, bool}
4
-
5
-
6
- def convert_to_dict(cfg_node, key_list=[]):
7
- """ Convert a config node to dictionary """
8
- if not isinstance(cfg_node, CfgNode):
9
- if type(cfg_node) not in _VALID_TYPES:
10
- print("Key {} with value {} is not a valid type; valid types: {}".format(
11
- ".".join(key_list), type(cfg_node), _VALID_TYPES), )
12
- return cfg_node
13
- else:
14
- cfg_dict = dict(cfg_node)
15
- for k, v in cfg_dict.items():
16
- cfg_dict[k] = convert_to_dict(v, key_list + [k])
17
- return cfg_dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/hrnet.py DELETED
@@ -1,625 +0,0 @@
1
- import os
2
-
3
- import torch
4
- import torch.nn as nn
5
- from loguru import logger
6
- import torch.nn.functional as F
7
- from yacs.config import CfgNode as CN
8
-
9
- models = [
10
- 'hrnet_w32',
11
- 'hrnet_w48',
12
- ]
13
-
14
- BN_MOMENTUM = 0.1
15
-
16
-
17
- def conv3x3(in_planes, out_planes, stride=1):
18
- """3x3 convolution with padding"""
19
- return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
20
- padding=1, bias=False)
21
-
22
-
23
- class BasicBlock(nn.Module):
24
- expansion = 1
25
-
26
- def __init__(self, inplanes, planes, stride=1, downsample=None):
27
- super(BasicBlock, self).__init__()
28
- self.conv1 = conv3x3(inplanes, planes, stride)
29
- self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
30
- self.relu = nn.ReLU(inplace=True)
31
- self.conv2 = conv3x3(planes, planes)
32
- self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
33
- self.downsample = downsample
34
- self.stride = stride
35
-
36
- def forward(self, x):
37
- residual = x
38
-
39
- out = self.conv1(x)
40
- out = self.bn1(out)
41
- out = self.relu(out)
42
-
43
- out = self.conv2(out)
44
- out = self.bn2(out)
45
-
46
- if self.downsample is not None:
47
- residual = self.downsample(x)
48
-
49
- out += residual
50
- out = self.relu(out)
51
-
52
- return out
53
-
54
-
55
- class Bottleneck(nn.Module):
56
- expansion = 4
57
-
58
- def __init__(self, inplanes, planes, stride=1, downsample=None):
59
- super(Bottleneck, self).__init__()
60
- self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
61
- self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
62
- self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
63
- padding=1, bias=False)
64
- self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
65
- self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
66
- bias=False)
67
- self.bn3 = nn.BatchNorm2d(planes * self.expansion,
68
- momentum=BN_MOMENTUM)
69
- self.relu = nn.ReLU(inplace=True)
70
- self.downsample = downsample
71
- self.stride = stride
72
-
73
- def forward(self, x):
74
- residual = x
75
-
76
- out = self.conv1(x)
77
- out = self.bn1(out)
78
- out = self.relu(out)
79
-
80
- out = self.conv2(out)
81
- out = self.bn2(out)
82
- out = self.relu(out)
83
-
84
- out = self.conv3(out)
85
- out = self.bn3(out)
86
-
87
- if self.downsample is not None:
88
- residual = self.downsample(x)
89
-
90
- out += residual
91
- out = self.relu(out)
92
-
93
- return out
94
-
95
-
96
- class HighResolutionModule(nn.Module):
97
- def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
98
- num_channels, fuse_method, multi_scale_output=True):
99
- super(HighResolutionModule, self).__init__()
100
- self._check_branches(
101
- num_branches, blocks, num_blocks, num_inchannels, num_channels)
102
-
103
- self.num_inchannels = num_inchannels
104
- self.fuse_method = fuse_method
105
- self.num_branches = num_branches
106
-
107
- self.multi_scale_output = multi_scale_output
108
-
109
- self.branches = self._make_branches(
110
- num_branches, blocks, num_blocks, num_channels)
111
- self.fuse_layers = self._make_fuse_layers()
112
- self.relu = nn.ReLU(True)
113
-
114
- def _check_branches(self, num_branches, blocks, num_blocks,
115
- num_inchannels, num_channels):
116
- if num_branches != len(num_blocks):
117
- error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
118
- num_branches, len(num_blocks))
119
- logger.error(error_msg)
120
- raise ValueError(error_msg)
121
-
122
- if num_branches != len(num_channels):
123
- error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
124
- num_branches, len(num_channels))
125
- logger.error(error_msg)
126
- raise ValueError(error_msg)
127
-
128
- if num_branches != len(num_inchannels):
129
- error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
130
- num_branches, len(num_inchannels))
131
- logger.error(error_msg)
132
- raise ValueError(error_msg)
133
-
134
- def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
135
- stride=1):
136
- downsample = None
137
- if stride != 1 or \
138
- self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
139
- downsample = nn.Sequential(
140
- nn.Conv2d(
141
- self.num_inchannels[branch_index],
142
- num_channels[branch_index] * block.expansion,
143
- kernel_size=1, stride=stride, bias=False
144
- ),
145
- nn.BatchNorm2d(
146
- num_channels[branch_index] * block.expansion,
147
- momentum=BN_MOMENTUM
148
- ),
149
- )
150
-
151
- layers = []
152
- layers.append(
153
- block(
154
- self.num_inchannels[branch_index],
155
- num_channels[branch_index],
156
- stride,
157
- downsample
158
- )
159
- )
160
- self.num_inchannels[branch_index] = \
161
- num_channels[branch_index] * block.expansion
162
- for i in range(1, num_blocks[branch_index]):
163
- layers.append(
164
- block(
165
- self.num_inchannels[branch_index],
166
- num_channels[branch_index]
167
- )
168
- )
169
-
170
- return nn.Sequential(*layers)
171
-
172
- def _make_branches(self, num_branches, block, num_blocks, num_channels):
173
- branches = []
174
-
175
- for i in range(num_branches):
176
- branches.append(
177
- self._make_one_branch(i, block, num_blocks, num_channels)
178
- )
179
-
180
- return nn.ModuleList(branches)
181
-
182
- def _make_fuse_layers(self):
183
- if self.num_branches == 1:
184
- return None
185
-
186
- num_branches = self.num_branches
187
- num_inchannels = self.num_inchannels
188
- fuse_layers = []
189
- for i in range(num_branches if self.multi_scale_output else 1):
190
- fuse_layer = []
191
- for j in range(num_branches):
192
- if j > i:
193
- fuse_layer.append(
194
- nn.Sequential(
195
- nn.Conv2d(
196
- num_inchannels[j],
197
- num_inchannels[i],
198
- 1, 1, 0, bias=False
199
- ),
200
- nn.BatchNorm2d(num_inchannels[i]),
201
- nn.Upsample(scale_factor=2**(j-i), mode='nearest')
202
- )
203
- )
204
- elif j == i:
205
- fuse_layer.append(None)
206
- else:
207
- conv3x3s = []
208
- for k in range(i-j):
209
- if k == i - j - 1:
210
- num_outchannels_conv3x3 = num_inchannels[i]
211
- conv3x3s.append(
212
- nn.Sequential(
213
- nn.Conv2d(
214
- num_inchannels[j],
215
- num_outchannels_conv3x3,
216
- 3, 2, 1, bias=False
217
- ),
218
- nn.BatchNorm2d(num_outchannels_conv3x3)
219
- )
220
- )
221
- else:
222
- num_outchannels_conv3x3 = num_inchannels[j]
223
- conv3x3s.append(
224
- nn.Sequential(
225
- nn.Conv2d(
226
- num_inchannels[j],
227
- num_outchannels_conv3x3,
228
- 3, 2, 1, bias=False
229
- ),
230
- nn.BatchNorm2d(num_outchannels_conv3x3),
231
- nn.ReLU(True)
232
- )
233
- )
234
- fuse_layer.append(nn.Sequential(*conv3x3s))
235
- fuse_layers.append(nn.ModuleList(fuse_layer))
236
-
237
- return nn.ModuleList(fuse_layers)
238
-
239
- def get_num_inchannels(self):
240
- return self.num_inchannels
241
-
242
- def forward(self, x):
243
- if self.num_branches == 1:
244
- return [self.branches[0](x[0])]
245
-
246
- for i in range(self.num_branches):
247
- x[i] = self.branches[i](x[i])
248
-
249
- x_fuse = []
250
-
251
- for i in range(len(self.fuse_layers)):
252
- y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
253
- for j in range(1, self.num_branches):
254
- if i == j:
255
- y = y + x[j]
256
- else:
257
- y = y + self.fuse_layers[i][j](x[j])
258
- x_fuse.append(self.relu(y))
259
-
260
- return x_fuse
261
-
262
-
263
- blocks_dict = {
264
- 'BASIC': BasicBlock,
265
- 'BOTTLENECK': Bottleneck
266
- }
267
-
268
-
269
- class PoseHighResolutionNet(nn.Module):
270
-
271
- def __init__(self, cfg):
272
- self.inplanes = 64
273
- extra = cfg['MODEL']['EXTRA']
274
- super(PoseHighResolutionNet, self).__init__()
275
-
276
- self.cfg = extra
277
-
278
- # stem net
279
- self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
280
- bias=False)
281
- self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
282
- self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
283
- bias=False)
284
- self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
285
- self.relu = nn.ReLU(inplace=True)
286
- self.layer1 = self._make_layer(Bottleneck, 64, 4)
287
-
288
- self.stage2_cfg = extra['STAGE2']
289
- num_channels = self.stage2_cfg['NUM_CHANNELS']
290
- block = blocks_dict[self.stage2_cfg['BLOCK']]
291
- num_channels = [
292
- num_channels[i] * block.expansion for i in range(len(num_channels))
293
- ]
294
- self.transition1 = self._make_transition_layer([256], num_channels)
295
- self.stage2, pre_stage_channels = self._make_stage(
296
- self.stage2_cfg, num_channels)
297
-
298
- self.stage3_cfg = extra['STAGE3']
299
- num_channels = self.stage3_cfg['NUM_CHANNELS']
300
- block = blocks_dict[self.stage3_cfg['BLOCK']]
301
- num_channels = [
302
- num_channels[i] * block.expansion for i in range(len(num_channels))
303
- ]
304
- self.transition2 = self._make_transition_layer(
305
- pre_stage_channels, num_channels)
306
- self.stage3, pre_stage_channels = self._make_stage(
307
- self.stage3_cfg, num_channels)
308
-
309
- self.stage4_cfg = extra['STAGE4']
310
- num_channels = self.stage4_cfg['NUM_CHANNELS']
311
- block = blocks_dict[self.stage4_cfg['BLOCK']]
312
- num_channels = [
313
- num_channels[i] * block.expansion for i in range(len(num_channels))
314
- ]
315
- self.transition3 = self._make_transition_layer(
316
- pre_stage_channels, num_channels)
317
- self.stage4, pre_stage_channels = self._make_stage(
318
- self.stage4_cfg, num_channels, multi_scale_output=True)
319
-
320
- self.final_layer = nn.Conv2d(
321
- in_channels=pre_stage_channels[0],
322
- out_channels=cfg['MODEL']['NUM_JOINTS'],
323
- kernel_size=extra['FINAL_CONV_KERNEL'],
324
- stride=1,
325
- padding=1 if extra['FINAL_CONV_KERNEL'] == 3 else 0
326
- )
327
-
328
- self.pretrained_layers = extra['PRETRAINED_LAYERS']
329
-
330
- if extra.DOWNSAMPLE and extra.USE_CONV:
331
- self.downsample_stage_1 = self._make_downsample_layer(3, num_channel=self.stage2_cfg['NUM_CHANNELS'][0])
332
- self.downsample_stage_2 = self._make_downsample_layer(2, num_channel=self.stage2_cfg['NUM_CHANNELS'][-1])
333
- self.downsample_stage_3 = self._make_downsample_layer(1, num_channel=self.stage3_cfg['NUM_CHANNELS'][-1])
334
- elif not extra.DOWNSAMPLE and extra.USE_CONV:
335
- self.upsample_stage_2 = self._make_upsample_layer(1, num_channel=self.stage2_cfg['NUM_CHANNELS'][-1])
336
- self.upsample_stage_3 = self._make_upsample_layer(2, num_channel=self.stage3_cfg['NUM_CHANNELS'][-1])
337
- self.upsample_stage_4 = self._make_upsample_layer(3, num_channel=self.stage4_cfg['NUM_CHANNELS'][-1])
338
-
339
- def _make_transition_layer(
340
- self, num_channels_pre_layer, num_channels_cur_layer):
341
- num_branches_cur = len(num_channels_cur_layer)
342
- num_branches_pre = len(num_channels_pre_layer)
343
-
344
- transition_layers = []
345
- for i in range(num_branches_cur):
346
- if i < num_branches_pre:
347
- if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
348
- transition_layers.append(
349
- nn.Sequential(
350
- nn.Conv2d(
351
- num_channels_pre_layer[i],
352
- num_channels_cur_layer[i],
353
- 3, 1, 1, bias=False
354
- ),
355
- nn.BatchNorm2d(num_channels_cur_layer[i]),
356
- nn.ReLU(inplace=True)
357
- )
358
- )
359
- else:
360
- transition_layers.append(None)
361
- else:
362
- conv3x3s = []
363
- for j in range(i+1-num_branches_pre):
364
- inchannels = num_channels_pre_layer[-1]
365
- outchannels = num_channels_cur_layer[i] \
366
- if j == i-num_branches_pre else inchannels
367
- conv3x3s.append(
368
- nn.Sequential(
369
- nn.Conv2d(
370
- inchannels, outchannels, 3, 2, 1, bias=False
371
- ),
372
- nn.BatchNorm2d(outchannels),
373
- nn.ReLU(inplace=True)
374
- )
375
- )
376
- transition_layers.append(nn.Sequential(*conv3x3s))
377
-
378
- return nn.ModuleList(transition_layers)
379
-
380
- def _make_layer(self, block, planes, blocks, stride=1):
381
- downsample = None
382
- if stride != 1 or self.inplanes != planes * block.expansion:
383
- downsample = nn.Sequential(
384
- nn.Conv2d(
385
- self.inplanes, planes * block.expansion,
386
- kernel_size=1, stride=stride, bias=False
387
- ),
388
- nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
389
- )
390
-
391
- layers = []
392
- layers.append(block(self.inplanes, planes, stride, downsample))
393
- self.inplanes = planes * block.expansion
394
- for i in range(1, blocks):
395
- layers.append(block(self.inplanes, planes))
396
-
397
- return nn.Sequential(*layers)
398
-
399
- def _make_stage(self, layer_config, num_inchannels,
400
- multi_scale_output=True):
401
- num_modules = layer_config['NUM_MODULES']
402
- num_branches = layer_config['NUM_BRANCHES']
403
- num_blocks = layer_config['NUM_BLOCKS']
404
- num_channels = layer_config['NUM_CHANNELS']
405
- block = blocks_dict[layer_config['BLOCK']]
406
- fuse_method = layer_config['FUSE_METHOD']
407
-
408
- modules = []
409
- for i in range(num_modules):
410
- # multi_scale_output is only used in the last module
411
- if not multi_scale_output and i == num_modules - 1:
412
- reset_multi_scale_output = False
413
- else:
414
- reset_multi_scale_output = True
415
-
416
- modules.append(
417
- HighResolutionModule(
418
- num_branches,
419
- block,
420
- num_blocks,
421
- num_inchannels,
422
- num_channels,
423
- fuse_method,
424
- reset_multi_scale_output
425
- )
426
- )
427
- num_inchannels = modules[-1].get_num_inchannels()
428
-
429
- return nn.Sequential(*modules), num_inchannels
430
-
431
- def _make_upsample_layer(self, num_layers, num_channel, kernel_size=3):
432
- layers = []
433
- for i in range(num_layers):
434
- layers.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
435
- layers.append(
436
- nn.Conv2d(
437
- in_channels=num_channel, out_channels=num_channel,
438
- kernel_size=kernel_size, stride=1, padding=1, bias=False,
439
- )
440
- )
441
- layers.append(nn.BatchNorm2d(num_channel, momentum=BN_MOMENTUM))
442
- layers.append(nn.ReLU(inplace=True))
443
-
444
- return nn.Sequential(*layers)
445
-
446
- def _make_downsample_layer(self, num_layers, num_channel, kernel_size=3):
447
- layers = []
448
- for i in range(num_layers):
449
- layers.append(
450
- nn.Conv2d(
451
- in_channels=num_channel, out_channels=num_channel,
452
- kernel_size=kernel_size, stride=2, padding=1, bias=False,
453
- )
454
- )
455
- layers.append(nn.BatchNorm2d(num_channel, momentum=BN_MOMENTUM))
456
- layers.append(nn.ReLU(inplace=True))
457
-
458
- return nn.Sequential(*layers)
459
-
460
- def forward(self, x):
461
- x = self.conv1(x)
462
- x = self.bn1(x)
463
- x = self.relu(x)
464
- x = self.conv2(x)
465
- x = self.bn2(x)
466
- x = self.relu(x)
467
- x = self.layer1(x)
468
-
469
- x_list = []
470
- for i in range(self.stage2_cfg['NUM_BRANCHES']):
471
- if self.transition1[i] is not None:
472
- x_list.append(self.transition1[i](x))
473
- else:
474
- x_list.append(x)
475
- y_list = self.stage2(x_list)
476
-
477
- x_list = []
478
- for i in range(self.stage3_cfg['NUM_BRANCHES']):
479
- if self.transition2[i] is not None:
480
- x_list.append(self.transition2[i](y_list[-1]))
481
- else:
482
- x_list.append(y_list[i])
483
- y_list = self.stage3(x_list)
484
-
485
- x_list = []
486
- for i in range(self.stage4_cfg['NUM_BRANCHES']):
487
- if self.transition3[i] is not None:
488
- x_list.append(self.transition3[i](y_list[-1]))
489
- else:
490
- x_list.append(y_list[i])
491
- x = self.stage4(x_list)
492
-
493
- if self.cfg.DOWNSAMPLE:
494
- if self.cfg.USE_CONV:
495
- # Downsampling with strided convolutions
496
- x1 = self.downsample_stage_1(x[0])
497
- x2 = self.downsample_stage_2(x[1])
498
- x3 = self.downsample_stage_3(x[2])
499
- x = torch.cat([x1, x2, x3, x[3]], 1)
500
- else:
501
- # Downsampling with interpolation
502
- x0_h, x0_w = x[3].size(2), x[3].size(3)
503
- x1 = F.interpolate(x[0], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
504
- x2 = F.interpolate(x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
505
- x3 = F.interpolate(x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
506
- x = torch.cat([x1, x2, x3, x[3]], 1)
507
- else:
508
- if self.cfg.USE_CONV:
509
- # Upsampling with interpolations + convolutions
510
- x1 = self.upsample_stage_2(x[1])
511
- x2 = self.upsample_stage_3(x[2])
512
- x3 = self.upsample_stage_4(x[3])
513
- x = torch.cat([x[0], x1, x2, x3], 1)
514
- else:
515
- # Upsampling with interpolation
516
- x0_h, x0_w = x[0].size(2), x[0].size(3)
517
- x1 = F.interpolate(x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
518
- x2 = F.interpolate(x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
519
- x3 = F.interpolate(x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
520
- x = torch.cat([x[0], x1, x2, x3], 1)
521
-
522
- return x
523
-
524
- def init_weights(self, pretrained=''):
525
- logger.info('=> init weights from normal distribution')
526
- for m in self.modules():
527
- if isinstance(m, nn.Conv2d):
528
- # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
529
- nn.init.normal_(m.weight, std=0.001)
530
- for name, _ in m.named_parameters():
531
- if name in ['bias']:
532
- nn.init.constant_(m.bias, 0)
533
- elif isinstance(m, nn.BatchNorm2d):
534
- nn.init.constant_(m.weight, 1)
535
- nn.init.constant_(m.bias, 0)
536
- elif isinstance(m, nn.ConvTranspose2d):
537
- nn.init.normal_(m.weight, std=0.001)
538
- for name, _ in m.named_parameters():
539
- if name in ['bias']:
540
- nn.init.constant_(m.bias, 0)
541
-
542
- if os.path.isfile(pretrained):
543
- pretrained_state_dict = torch.load(pretrained)
544
- logger.info('=> loading pretrained model {}'.format(pretrained))
545
-
546
- need_init_state_dict = {}
547
- for name, m in pretrained_state_dict.items():
548
- if name.split('.')[0] in self.pretrained_layers \
549
- or self.pretrained_layers[0] == '*':
550
- need_init_state_dict[name] = m
551
- self.load_state_dict(need_init_state_dict, strict=False)
552
- elif pretrained:
553
- logger.warning('IMPORTANT WARNING!! Please download pre-trained models if you are in TRAINING mode!')
554
- # raise ValueError('{} is not exist!'.format(pretrained))
555
-
556
-
557
- def get_pose_net(cfg, is_train):
558
- model = PoseHighResolutionNet(cfg)
559
-
560
- if is_train and cfg['MODEL']['INIT_WEIGHTS']:
561
- model.init_weights(cfg['MODEL']['PRETRAINED'])
562
-
563
- return model
564
-
565
-
566
- def get_cfg_defaults(pretrained, width=32, downsample=False, use_conv=False):
567
- # pose_multi_resolution_net related params
568
- HRNET = CN()
569
- HRNET.PRETRAINED_LAYERS = [
570
- 'conv1', 'bn1', 'conv2', 'bn2', 'layer1', 'transition1',
571
- 'stage2', 'transition2', 'stage3', 'transition3', 'stage4',
572
- ]
573
- HRNET.STEM_INPLANES = 64
574
- HRNET.FINAL_CONV_KERNEL = 1
575
- HRNET.STAGE2 = CN()
576
- HRNET.STAGE2.NUM_MODULES = 1
577
- HRNET.STAGE2.NUM_BRANCHES = 2
578
- HRNET.STAGE2.NUM_BLOCKS = [4, 4]
579
- HRNET.STAGE2.NUM_CHANNELS = [width, width*2]
580
- HRNET.STAGE2.BLOCK = 'BASIC'
581
- HRNET.STAGE2.FUSE_METHOD = 'SUM'
582
- HRNET.STAGE3 = CN()
583
- HRNET.STAGE3.NUM_MODULES = 4
584
- HRNET.STAGE3.NUM_BRANCHES = 3
585
- HRNET.STAGE3.NUM_BLOCKS = [4, 4, 4]
586
- HRNET.STAGE3.NUM_CHANNELS = [width, width*2, width*4]
587
- HRNET.STAGE3.BLOCK = 'BASIC'
588
- HRNET.STAGE3.FUSE_METHOD = 'SUM'
589
- HRNET.STAGE4 = CN()
590
- HRNET.STAGE4.NUM_MODULES = 3
591
- HRNET.STAGE4.NUM_BRANCHES = 4
592
- HRNET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
593
- HRNET.STAGE4.NUM_CHANNELS = [width, width*2, width*4, width*8]
594
- HRNET.STAGE4.BLOCK = 'BASIC'
595
- HRNET.STAGE4.FUSE_METHOD = 'SUM'
596
- HRNET.DOWNSAMPLE = downsample
597
- HRNET.USE_CONV = use_conv
598
-
599
- cfg = CN()
600
- cfg.MODEL = CN()
601
- cfg.MODEL.INIT_WEIGHTS = True
602
- cfg.MODEL.PRETRAINED = pretrained # 'data/pretrained_models/hrnet_w32-36af842e.pth'
603
- cfg.MODEL.EXTRA = HRNET
604
- cfg.MODEL.NUM_JOINTS = 24
605
- return cfg
606
-
607
-
608
- def hrnet_w32(
609
- pretrained=True,
610
- pretrained_ckpt='data/weights/pose_hrnet_w32_256x192.pth',
611
- downsample=False,
612
- use_conv=False,
613
- ):
614
- cfg = get_cfg_defaults(pretrained_ckpt, width=32, downsample=downsample, use_conv=use_conv)
615
- return get_pose_net(cfg, is_train=True)
616
-
617
-
618
- def hrnet_w48(
619
- pretrained=True,
620
- pretrained_ckpt='data/weights/pose_hrnet_w48_256x192.pth',
621
- downsample=False,
622
- use_conv=False,
623
- ):
624
- cfg = get_cfg_defaults(pretrained_ckpt, width=48, downsample=downsample, use_conv=use_conv)
625
- return get_pose_net(cfg, is_train=True)
 
utils/image_utils.py DELETED
@@ -1,444 +0,0 @@
1
- """
2
- This file contains functions that are used to perform data augmentation.
3
- """
4
- import cv2
5
- import torch
6
- import json
7
- from skimage.transform import rotate, resize
8
- import numpy as np
9
- import jpeg4py as jpeg
10
- from trimesh.visual import color
11
-
12
- # from ..core import constants
13
- # from .vibe_image_utils import gen_trans_from_patch_cv
14
- from .kp_utils import map_smpl_to_common, get_smpl_joint_names
15
-
16
- def get_transform(center, scale, res, rot=0):
17
- """Generate transformation matrix."""
18
- h = 200 * scale
19
- t = np.zeros((3, 3))
20
- t[0, 0] = float(res[1]) / h
21
- t[1, 1] = float(res[0]) / h
22
- t[0, 2] = res[1] * (-float(center[0]) / h + .5)
23
- t[1, 2] = res[0] * (-float(center[1]) / h + .5)
24
- t[2, 2] = 1
25
- if not rot == 0:
26
- rot = -rot # To match direction of rotation from cropping
27
- rot_mat = np.zeros((3, 3))
28
- rot_rad = rot * np.pi / 180
29
- sn, cs = np.sin(rot_rad), np.cos(rot_rad)
30
- rot_mat[0, :2] = [cs, -sn]
31
- rot_mat[1, :2] = [sn, cs]
32
- rot_mat[2, 2] = 1
33
- # Need to rotate around center
34
- t_mat = np.eye(3)
35
- t_mat[0, 2] = -res[1] / 2
36
- t_mat[1, 2] = -res[0] / 2
37
- t_inv = t_mat.copy()
38
- t_inv[:2, 2] *= -1
39
- t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
40
- return t
41
-
42
-
43
- def transform(pt, center, scale, res, invert=0, rot=0):
44
- """Transform pixel location to different reference."""
45
- t = get_transform(center, scale, res, rot=rot)
46
- if invert:
47
- t = np.linalg.inv(t)
48
- new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
49
- new_pt = np.dot(t, new_pt)
50
- return new_pt[:2].astype(int) + 1
51
-
52
-
53
- def crop(img, center, scale, res, rot=0):
54
- """Crop image according to the supplied bounding box."""
55
- # Upper left point
56
- ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1
57
- # Bottom right point
58
- br = np.array(transform([res[0] + 1,
59
- res[1] + 1], center, scale, res, invert=1)) - 1
60
-
61
- # Padding so that when rotated proper amount of context is included
62
- pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
63
- if not rot == 0:
64
- ul -= pad
65
- br += pad
66
-
67
- new_shape = [br[1] - ul[1], br[0] - ul[0]]
68
- if len(img.shape) > 2:
69
- new_shape += [img.shape[2]]
70
- new_img = np.zeros(new_shape)
71
-
72
- # Range to fill new array
73
- new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
74
- new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
75
- # Range to sample from original image
76
- old_x = max(0, ul[0]), min(len(img[0]), br[0])
77
- old_y = max(0, ul[1]), min(len(img), br[1])
78
- new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
79
- old_x[0]:old_x[1]]
80
-
81
- if not rot == 0:
82
- # Remove padding
83
-
84
- new_img = rotate(new_img, rot) # scipy.misc.imrotate(new_img, rot)
85
- new_img = new_img[pad:-pad, pad:-pad]
86
-
87
- # resize image
88
- new_img = resize(new_img, res) # scipy.misc.imresize(new_img, res)
89
- return new_img
90
-
91
-
92
- def crop_cv2(img, center, scale, res, rot=0):
93
- c_x, c_y = center
94
- c_x, c_y = int(round(c_x)), int(round(c_y))
95
- patch_width, patch_height = int(round(res[0])), int(round(res[1]))
96
- bb_width = bb_height = int(round(scale * 200.))
97
-
98
- trans = gen_trans_from_patch_cv(
99
- c_x, c_y, bb_width, bb_height,
100
- patch_width, patch_height,
101
- scale=1.0, rot=rot, inv=False,
102
- )
103
-
104
- crop_img = cv2.warpAffine(
105
- img, trans, (int(patch_width), int(patch_height)),
106
- flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT
107
- )
108
-
109
- return crop_img
110
-
111
-
112
- def get_random_crop_coords(height, width, crop_height, crop_width, h_start, w_start):
113
- y1 = int((height - crop_height) * h_start)
114
- y2 = y1 + crop_height
115
- x1 = int((width - crop_width) * w_start)
116
- x2 = x1 + crop_width
117
- return x1, y1, x2, y2
118
-
119
-
120
- def random_crop(center, scale, crop_scale_factor, axis='all'):
121
- '''
122
- center: bbox center [x,y]
123
- scale: bbox height / 200
124
- crop_scale_factor: amount of cropping to be applied
125
- axis: axis which cropping will be applied
126
- "x": center the y axis and get random crops in x
127
- "y": center the x axis and get random crops in y
128
- "all": randomly crop from all locations
129
- '''
130
- orig_size = int(scale * 200.)
131
- ul = (center - (orig_size / 2.)).astype(int)
132
-
133
- crop_size = int(orig_size * crop_scale_factor)
134
-
135
- if axis == 'all':
136
- h_start = np.random.rand()
137
- w_start = np.random.rand()
138
- elif axis == 'x':
139
- h_start = np.random.rand()
140
- w_start = 0.5
141
- elif axis == 'y':
142
- h_start = 0.5
143
- w_start = np.random.rand()
144
- else:
145
- raise ValueError(f'axis {axis} is undefined!')
146
-
147
- x1, y1, x2, y2 = get_random_crop_coords(
148
- height=orig_size,
149
- width=orig_size,
150
- crop_height=crop_size,
151
- crop_width=crop_size,
152
- h_start=h_start,
153
- w_start=w_start,
154
- )
155
- scale = (y2 - y1) / 200.
156
- center = ul + np.array([(y1 + y2) / 2, (x1 + x2) / 2])
157
- return center, scale
158
-
159
-
160
- def uncrop(img, center, scale, orig_shape, rot=0, is_rgb=True):
161
- """'Undo' the image cropping/resizing.
162
- This function is used when evaluating mask/part segmentation.
163
- """
164
- res = img.shape[:2]
165
- # Upper left point
166
- ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1
167
- # Bottom right point
168
- br = np.array(transform([res[0] + 1, res[1] + 1], center, scale, res, invert=1)) - 1
169
- # size of cropped image
170
- crop_shape = [br[1] - ul[1], br[0] - ul[0]]
171
-
172
- new_shape = [br[1] - ul[1], br[0] - ul[0]]
173
- if len(img.shape) > 2:
174
- new_shape += [img.shape[2]]
175
- new_img = np.zeros(orig_shape, dtype=np.uint8)
176
- # Range to fill new array
177
- new_x = max(0, -ul[0]), min(br[0], orig_shape[1]) - ul[0]
178
- new_y = max(0, -ul[1]), min(br[1], orig_shape[0]) - ul[1]
179
- # Range to sample from original image
180
- old_x = max(0, ul[0]), min(orig_shape[1], br[0])
181
- old_y = max(0, ul[1]), min(orig_shape[0], br[1])
182
- img = resize(img, crop_shape) #, interp='nearest') # scipy.misc.imresize(img, crop_shape, interp='nearest')
183
- new_img[old_y[0]:old_y[1], old_x[0]:old_x[1]] = img[new_y[0]:new_y[1], new_x[0]:new_x[1]]
184
- return new_img
185
-
186
-
187
- def rot_aa(aa, rot):
188
- """Rotate axis angle parameters."""
189
- # pose parameters
190
- R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
191
- [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
192
- [0, 0, 1]])
193
- # find the rotation of the body in camera frame
194
- per_rdg, _ = cv2.Rodrigues(aa)
195
- # apply the global rotation to the global orientation
196
- resrot, _ = cv2.Rodrigues(np.dot(R, per_rdg))
197
- aa = (resrot.T)[0]
198
- return aa
199
-
200
-
201
- def flip_img(img):
202
- """Flip rgb images or masks.
203
- channels come last, e.g. (256,256,3).
204
- """
205
- img = np.fliplr(img)
206
- return img
207
-
208
-
209
- def flip_kp(kp):
210
- """Flip keypoints."""
211
- if len(kp) == 24:
212
- flipped_parts = constants.J24_FLIP_PERM
213
- elif len(kp) == 49:
214
- flipped_parts = constants.J49_FLIP_PERM
215
- kp = kp[flipped_parts]
216
- kp[:, 0] = - kp[:, 0]
217
- return kp
218
-
219
-
220
- def flip_pose(pose):
221
- """Flip pose.
222
- The flipping is based on SMPL parameters.
223
- """
224
- flipped_parts = constants.SMPL_POSE_FLIP_PERM
225
- pose = pose[flipped_parts]
226
- # we also negate the second and the third dimension of the axis-angle
227
- pose[1::3] = -pose[1::3]
228
- pose[2::3] = -pose[2::3]
229
- return pose
230
-
231
-
232
- def denormalize_images(images):
233
- images = images * torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1, 3, 1, 1)
234
- images = images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1, 3, 1, 1)
235
- return images
236
-
237
-
238
- def read_img(img_fn):
239
- # return pil_img.fromarray(
240
- # cv2.cvtColor(cv2.imread(img_fn), cv2.COLOR_BGR2RGB))
241
- # with open(img_fn, 'rb') as f:
242
- # img = pil_img.open(f).convert('RGB')
243
- # return img
244
- if img_fn.endswith('jpeg') or img_fn.endswith('jpg'):
245
- try:
246
- with open(img_fn, 'rb') as f:
247
- img = np.array(jpeg.JPEG(f).decode())
248
- except jpeg.JPEGRuntimeError:
249
- # logger.warning('{} produced a JPEGRuntimeError', img_fn)
250
- img = cv2.cvtColor(cv2.imread(img_fn), cv2.COLOR_BGR2RGB)
251
- else:
252
- # elif img_fn.endswith('png') or img_fn.endswith('JPG') or img_fn.endswith(''):
253
- img = cv2.cvtColor(cv2.imread(img_fn), cv2.COLOR_BGR2RGB)
254
- return img.astype(np.float32)
255
-
256
-
257
- def generate_heatmaps_2d(joints, joints_vis, num_joints=24, heatmap_size=56, image_size=224, sigma=1.75):
258
- '''
259
- :param joints: [num_joints, 3]
260
- :param joints_vis: [num_joints, 3]
261
- :return: target, target_weight(1: visible, 0: invisible)
262
- '''
263
- target_weight = np.ones((num_joints, 1), dtype=np.float32)
264
- target_weight[:, 0] = joints_vis[:, 0]
265
-
266
- target = np.zeros((num_joints, heatmap_size, heatmap_size), dtype=np.float32)
267
-
268
- tmp_size = sigma * 3
269
-
270
- # denormalize joint into heatmap coordinates
271
- joints = (joints + 1.) * (image_size / 2.)
272
-
273
- for joint_id in range(num_joints):
274
- feat_stride = image_size / heatmap_size
275
- mu_x = int(joints[joint_id][0] / feat_stride + 0.5)
276
- mu_y = int(joints[joint_id][1] / feat_stride + 0.5)
277
- # Check that any part of the gaussian is in-bounds
278
- ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
279
- br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
280
- if ul[0] >= heatmap_size or ul[1] >= heatmap_size \
281
- or br[0] < 0 or br[1] < 0:
282
- # If not, just return the image as is
283
- target_weight[joint_id] = 0
284
- continue
285
-
286
- # # Generate gaussian
287
- size = 2 * tmp_size + 1
288
- x = np.arange(0, size, 1, np.float32)
289
- y = x[:, np.newaxis]
290
- x0 = y0 = size // 2
291
- # The gaussian is not normalized, we want the center value to equal 1
292
- g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
293
-
294
- # Usable gaussian range
295
- g_x = max(0, -ul[0]), min(br[0], heatmap_size) - ul[0]
296
- g_y = max(0, -ul[1]), min(br[1], heatmap_size) - ul[1]
297
- # Image range
298
- img_x = max(0, ul[0]), min(br[0], heatmap_size)
299
- img_y = max(0, ul[1]), min(br[1], heatmap_size)
300
-
301
- v = target_weight[joint_id]
302
- if v > 0.5:
303
- target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
304
- g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
305
-
306
- return target, target_weight
307
-
308
-
309
- def generate_part_labels(vertices, faces, cam_t, neural_renderer, body_part_texture, K, R, part_bins):
310
- batch_size = vertices.shape[0]
311
-
312
- body_parts, depth, mask = neural_renderer(
313
- vertices,
314
- faces.expand(batch_size, -1, -1),
315
- textures=body_part_texture.expand(batch_size, -1, -1, -1, -1, -1),
316
- K=K.expand(batch_size, -1, -1),
317
- R=R.expand(batch_size, -1, -1),
318
- t=cam_t.unsqueeze(1),
319
- )
320
-
321
- render_rgb = body_parts.clone()
322
-
323
- body_parts = body_parts.permute(0, 2, 3, 1)
324
- body_parts *= 255. # multiply it with 255 to make labels distant
325
- body_parts, _ = body_parts.max(-1) # reduce to single channel
326
-
327
- body_parts = torch.bucketize(body_parts.detach(), part_bins, right=True) # np.digitize(body_parts, bins, right=True)
328
-
329
- # add 1 to make background label 0
330
- body_parts = body_parts.long() + 1
331
- body_parts = body_parts * mask.detach()
332
-
333
- return body_parts.long(), render_rgb
334
-
335
-
336
- def generate_heatmaps_2d_batch(joints, num_joints=24, heatmap_size=56, image_size=224, sigma=1.75):
337
- batch_size = joints.shape[0]
338
-
339
- joints = joints.detach().cpu().numpy()
340
- joints_vis = np.ones_like(joints)
341
-
342
- heatmaps = []
343
- heatmaps_vis = []
344
- for i in range(batch_size):
345
- hm, hm_vis = generate_heatmaps_2d(joints[i], joints_vis[i], num_joints, heatmap_size, image_size, sigma)
346
- heatmaps.append(hm)
347
- heatmaps_vis.append(hm_vis)
348
-
349
- return torch.from_numpy(np.stack(heatmaps)).float().to('cuda'), \
350
- torch.from_numpy(np.stack(heatmaps_vis)).float().to('cuda')
351
-
352
-
353
- def get_body_part_texture(faces, model_type='smpl', non_parametric=False):
354
- if model_type == 'smpl':
355
- n_vertices = 6890
356
- segmentation_path = 'data/smpl_vert_segmentation.json'
357
- if model_type == 'smplx':
358
- n_vertices = 10475
359
- segmentation_path = 'data/smplx_vert_segmentation.json'
360
-
361
- with open(segmentation_path, 'rb') as f:
362
- part_segmentation = json.load(f)
363
-
364
- # map all vertex ids to the joint ids
365
- joint_names = get_smpl_joint_names()
366
- smplx_extra_joint_names = ['leftEye', 'eyeballs', 'rightEye']
367
- body_vert_idx = np.zeros((n_vertices), dtype=np.int32) - 1 # -1 for missing label
368
- for i, (k, v) in enumerate(part_segmentation.items()):
369
- if k in smplx_extra_joint_names and model_type == 'smplx':
370
- k = 'head' # map all extra smplx face joints to head
371
- body_joint_idx = joint_names.index(k)
372
- body_vert_idx[v] = body_joint_idx
373
-
374
- # pare implementation
375
- # import joblib
376
- # part_segmentation = joblib.load('data/smpl_partSegmentation_mapping.pkl')
377
- # body_vert_idx = part_segmentation['smpl_index']
378
-
379
- n_parts = 24.
380
-
381
- if non_parametric:
382
- # reduce the number of body_parts to 14
383
- # by mapping some joints to others
384
- n_parts = 14.
385
- joint_mapping = map_smpl_to_common()
386
-
387
- for jm in joint_mapping:
388
- for j in jm[0]:
389
- body_vert_idx[body_vert_idx==j] = jm[1]
390
-
391
- vertex_colors = np.ones((n_vertices, 4))
392
- vertex_colors[:, :3] = body_vert_idx[..., None]
393
-
394
- vertex_colors = color.to_rgba(vertex_colors)
395
- vertex_colors = vertex_colors[:, :3]/255.
396
-
397
- face_colors = vertex_colors[faces].min(axis=1)
398
- texture = np.zeros((1, faces.shape[0], 1, 1, 3), dtype=np.float32)
399
- # texture[0, :, 0, 0, :] = face_colors[:, :3] / n_parts
400
- texture[0, :, 0, 0, :] = face_colors[:, :3]
401
-
402
- vertex_colors = torch.from_numpy(vertex_colors).float()
403
- texture = torch.from_numpy(texture).float()
404
- return vertex_colors, texture
405
-
406
-
407
- def get_default_camera(focal_length, img_h, img_w, is_cam_batch=False):
408
- if not is_cam_batch:
409
- K = torch.eye(3)
410
- K[0, 0] = focal_length
411
- K[1, 1] = focal_length
412
- K[2, 2] = 1
413
- K[0, 2] = img_w / 2.
414
- K[1, 2] = img_h / 2.
415
- K = K[None, :, :]
416
- R = torch.eye(3)[None, :, :]
417
- else:
418
- bs = focal_length.shape[0]
419
- K = torch.eye(3)[None, :, :].repeat(bs, 1, 1)
420
- K[:, 0, 0] = focal_length[:, 0]
421
- K[:, 1, 1] = focal_length[:, 1]
422
- K[:, 2, 2] = 1
423
- K[:, 0, 2] = img_w / 2.
424
- K[:, 1, 2] = img_h / 2.
425
- R = torch.eye(3)[None, :, :].repeat(bs, 1, 1)
426
- return K, R
427
-
428
-
429
- def read_exif_data(img_fname):
430
- import PIL.Image
431
- import PIL.ExifTags
432
-
433
- img = PIL.Image.open(img_fname)
434
- exif_data = img._getexif()
435
-
436
- if exif_data is None:
437
- return None
438
-
439
- exif = {
440
- PIL.ExifTags.TAGS[k]: v
441
- for k, v in exif_data.items()
442
- if k in PIL.ExifTags.TAGS
443
- }
444
- return exif
 
utils/kp_utils.py DELETED
@@ -1,1114 +0,0 @@
1
- import numpy as np
2
-
3
-
4
- def keypoint_hflip(kp, img_width):
5
- # Flip a keypoint horizontally around the y-axis
6
- # kp N,2
7
- if len(kp.shape) == 2:
8
- kp[:,0] = (img_width - 1.) - kp[:,0]
9
- elif len(kp.shape) == 3:
10
- kp[:, :, 0] = (img_width - 1.) - kp[:, :, 0]
11
- return kp
12
-
13
-
14
- def convert_kps(joints2d, src, dst):
15
- src_names = eval(f'get_{src}_joint_names')()
16
- dst_names = eval(f'get_{dst}_joint_names')()
17
-
18
- out_joints2d = np.zeros((joints2d.shape[0], len(dst_names), joints2d.shape[-1]))
19
-
20
- for idx, jn in enumerate(dst_names):
21
- if jn in src_names:
22
- out_joints2d[:, idx] = joints2d[:, src_names.index(jn)]
23
-
24
- return out_joints2d
25
-
26
-
27
- def get_perm_idxs(src, dst):
28
- src_names = eval(f'get_{src}_joint_names')()
29
- dst_names = eval(f'get_{dst}_joint_names')()
30
- idxs = [src_names.index(h) for h in dst_names if h in src_names]
31
- return idxs
32
-
33
-
34
- def get_mpii3d_test_joint_names():
35
- return [
36
- 'headtop', # 'head_top',
37
- 'neck',
38
- 'rshoulder',# 'right_shoulder',
39
- 'relbow',# 'right_elbow',
40
- 'rwrist',# 'right_wrist',
41
- 'lshoulder',# 'left_shoulder',
42
- 'lelbow', # 'left_elbow',
43
- 'lwrist', # 'left_wrist',
44
- 'rhip', # 'right_hip',
45
- 'rknee', # 'right_knee',
46
- 'rankle',# 'right_ankle',
47
- 'lhip',# 'left_hip',
48
- 'lknee',# 'left_knee',
49
- 'lankle',# 'left_ankle'
50
- 'hip',# 'pelvis',
51
- 'Spine (H36M)',# 'spine',
52
- 'Head (H36M)',# 'head'
53
- ]
54
-
55
-
56
- def get_mpii3d_joint_names():
57
- return [
58
- 'spine3', # 0,
59
- 'spine4', # 1,
60
- 'spine2', # 2,
61
- 'Spine (H36M)', #'spine', # 3,
62
- 'hip', # 'pelvis', # 4,
63
- 'neck', # 5,
64
- 'Head (H36M)', # 'head', # 6,
65
- "headtop", # 'head_top', # 7,
66
- 'left_clavicle', # 8,
67
- "lshoulder", # 'left_shoulder', # 9,
68
- "lelbow", # 'left_elbow',# 10,
69
- "lwrist", # 'left_wrist',# 11,
70
- 'left_hand',# 12,
71
- 'right_clavicle',# 13,
72
- 'rshoulder',# 'right_shoulder',# 14,
73
- 'relbow',# 'right_elbow',# 15,
74
- 'rwrist',# 'right_wrist',# 16,
75
- 'right_hand',# 17,
76
- 'lhip', # left_hip',# 18,
77
- 'lknee', # 'left_knee',# 19,
78
- 'lankle', #left ankle # 20
79
- 'left_foot', # 21
80
- 'left_toe', # 22
81
- "rhip", # 'right_hip',# 23
82
- "rknee", # 'right_knee',# 24
83
- "rankle", #'right_ankle', # 25
84
- 'right_foot',# 26
85
- 'right_toe' # 27
86
- ]
87
-
88
-
89
- # def get_insta_joint_names():
90
- # return [
91
- # 'rheel' , # 0
92
- # 'rknee' , # 1
93
- # 'rhip' , # 2
94
- # 'lhip' , # 3
95
- # 'lknee' , # 4
96
- # 'lheel' , # 5
97
- # 'rwrist' , # 6
98
- # 'relbow' , # 7
99
- # 'rshoulder' , # 8
100
- # 'lshoulder' , # 9
101
- # 'lelbow' , # 10
102
- # 'lwrist' , # 11
103
- # 'neck' , # 12
104
- # 'headtop' , # 13
105
- # 'nose' , # 14
106
- # 'leye' , # 15
107
- # 'reye' , # 16
108
- # 'lear' , # 17
109
- # 'rear' , # 18
110
- # 'lbigtoe' , # 19
111
- # 'rbigtoe' , # 20
112
- # 'lsmalltoe' , # 21
113
- # 'rsmalltoe' , # 22
114
- # 'lankle' , # 23
115
- # 'rankle' , # 24
116
- # ]
117
-
118
-
119
- def get_insta_joint_names():
120
- return [
121
- 'OP RHeel',
122
- 'OP RKnee',
123
- 'OP RHip',
124
- 'OP LHip',
125
- 'OP LKnee',
126
- 'OP LHeel',
127
- 'OP RWrist',
128
- 'OP RElbow',
129
- 'OP RShoulder',
130
- 'OP LShoulder',
131
- 'OP LElbow',
132
- 'OP LWrist',
133
- 'OP Neck',
134
- 'headtop',
135
- 'OP Nose',
136
- 'OP LEye',
137
- 'OP REye',
138
- 'OP LEar',
139
- 'OP REar',
140
- 'OP LBigToe',
141
- 'OP RBigToe',
142
- 'OP LSmallToe',
143
- 'OP RSmallToe',
144
- 'OP LAnkle',
145
- 'OP RAnkle',
146
- ]
147
-
148
-
149
- def get_mmpose_joint_names():
150
- # this naming is for the first 23 joints of MMPose
151
- # does not include hands and face
152
- return [
153
- 'OP Nose', # 1
154
- 'OP LEye', # 2
155
- 'OP REye', # 3
156
- 'OP LEar', # 4
157
- 'OP REar', # 5
158
- 'OP LShoulder', # 6
159
- 'OP RShoulder', # 7
160
- 'OP LElbow', # 8
161
- 'OP RElbow', # 9
162
- 'OP LWrist', # 10
163
- 'OP RWrist', # 11
164
- 'OP LHip', # 12
165
- 'OP RHip', # 13
166
- 'OP LKnee', # 14
167
- 'OP RKnee', # 15
168
- 'OP LAnkle', # 16
169
- 'OP RAnkle', # 17
170
- 'OP LBigToe', # 18
171
- 'OP LSmallToe', # 19
172
- 'OP LHeel', # 20
173
- 'OP RBigToe', # 21
174
- 'OP RSmallToe', # 22
175
- 'OP RHeel', # 23
176
- ]
177
-
178
-
179
- def get_insta_skeleton():
180
- return np.array(
181
- [
182
- [0 , 1],
183
- [1 , 2],
184
- [2 , 3],
185
- [3 , 4],
186
- [4 , 5],
187
- [6 , 7],
188
- [7 , 8],
189
- [8 , 9],
190
- [9 ,10],
191
- [2 , 8],
192
- [3 , 9],
193
- [10,11],
194
- [8 ,12],
195
- [9 ,12],
196
- [12,13],
197
- [12,14],
198
- [14,15],
199
- [14,16],
200
- [15,17],
201
- [16,18],
202
- [0 ,20],
203
- [20,22],
204
- [5 ,19],
205
- [19,21],
206
- [5 ,23],
207
- [0 ,24],
208
- ])
209
-
210
-
211
- def get_staf_skeleton():
212
- return np.array(
213
- [
214
- [0, 1],
215
- [1, 2],
216
- [2, 3],
217
- [3, 4],
218
- [1, 5],
219
- [5, 6],
220
- [6, 7],
221
- [1, 8],
222
- [8, 9],
223
- [9, 10],
224
- [10, 11],
225
- [8, 12],
226
- [12, 13],
227
- [13, 14],
228
- [0, 15],
229
- [0, 16],
230
- [15, 17],
231
- [16, 18],
232
- [2, 9],
233
- [5, 12],
234
- [1, 19],
235
- [20, 19],
236
- ]
237
- )
238
-
239
-
240
- def get_staf_joint_names():
241
- return [
242
- 'OP Nose', # 0,
243
- 'OP Neck', # 1,
244
- 'OP RShoulder', # 2,
245
- 'OP RElbow', # 3,
246
- 'OP RWrist', # 4,
247
- 'OP LShoulder', # 5,
248
- 'OP LElbow', # 6,
249
- 'OP LWrist', # 7,
250
- 'OP MidHip', # 8,
251
- 'OP RHip', # 9,
252
- 'OP RKnee', # 10,
253
- 'OP RAnkle', # 11,
254
- 'OP LHip', # 12,
255
- 'OP LKnee', # 13,
256
- 'OP LAnkle', # 14,
257
- 'OP REye', # 15,
258
- 'OP LEye', # 16,
259
- 'OP REar', # 17,
260
- 'OP LEar', # 18,
261
- 'Neck (LSP)', # 19,
262
- 'Top of Head (LSP)', # 20,
263
- ]
264
-
265
-
266
- def get_spin_op_joint_names():
267
- return [
268
- 'OP Nose', # 0
269
- 'OP Neck', # 1
270
- 'OP RShoulder', # 2
271
- 'OP RElbow', # 3
272
- 'OP RWrist', # 4
273
- 'OP LShoulder', # 5
274
- 'OP LElbow', # 6
275
- 'OP LWrist', # 7
276
- 'OP MidHip', # 8
277
- 'OP RHip', # 9
278
- 'OP RKnee', # 10
279
- 'OP RAnkle', # 11
280
- 'OP LHip', # 12
281
- 'OP LKnee', # 13
282
- 'OP LAnkle', # 14
283
- 'OP REye', # 15
284
- 'OP LEye', # 16
285
- 'OP REar', # 17
286
- 'OP LEar', # 18
287
- 'OP LBigToe', # 19
288
- 'OP LSmallToe', # 20
289
- 'OP LHeel', # 21
290
- 'OP RBigToe', # 22
291
- 'OP RSmallToe', # 23
292
- 'OP RHeel', # 24
293
- ]
294
-
295
-
296
- def get_openpose_joint_names():
297
- return [
298
- 'OP Nose', # 0
299
- 'OP Neck', # 1
300
- 'OP RShoulder', # 2
301
- 'OP RElbow', # 3
302
- 'OP RWrist', # 4
303
- 'OP LShoulder', # 5
304
- 'OP LElbow', # 6
305
- 'OP LWrist', # 7
306
- 'OP MidHip', # 8
307
- 'OP RHip', # 9
308
- 'OP RKnee', # 10
309
- 'OP RAnkle', # 11
310
- 'OP LHip', # 12
311
- 'OP LKnee', # 13
312
- 'OP LAnkle', # 14
313
- 'OP REye', # 15
314
- 'OP LEye', # 16
315
- 'OP REar', # 17
316
- 'OP LEar', # 18
317
- 'OP LBigToe', # 19
318
- 'OP LSmallToe', # 20
319
- 'OP LHeel', # 21
320
- 'OP RBigToe', # 22
321
- 'OP RSmallToe', # 23
322
- 'OP RHeel', # 24
323
- ]
324
-
325
-
326
- def get_spin_joint_names():
327
- return [
328
- 'OP Nose', # 0
329
- 'OP Neck', # 1
330
- 'OP RShoulder', # 2
331
- 'OP RElbow', # 3
332
- 'OP RWrist', # 4
333
- 'OP LShoulder', # 5
334
- 'OP LElbow', # 6
335
- 'OP LWrist', # 7
336
- 'OP MidHip', # 8
337
- 'OP RHip', # 9
338
- 'OP RKnee', # 10
339
- 'OP RAnkle', # 11
340
- 'OP LHip', # 12
341
- 'OP LKnee', # 13
342
- 'OP LAnkle', # 14
343
- 'OP REye', # 15
344
- 'OP LEye', # 16
345
- 'OP REar', # 17
346
- 'OP LEar', # 18
347
- 'OP LBigToe', # 19
348
- 'OP LSmallToe', # 20
349
- 'OP LHeel', # 21
350
- 'OP RBigToe', # 22
351
- 'OP RSmallToe', # 23
352
- 'OP RHeel', # 24
353
- 'rankle', # 25
354
- 'rknee', # 26
355
- 'rhip', # 27
356
- 'lhip', # 28
357
- 'lknee', # 29
358
- 'lankle', # 30
359
- 'rwrist', # 31
360
- 'relbow', # 32
361
- 'rshoulder', # 33
362
- 'lshoulder', # 34
363
- 'lelbow', # 35
364
- 'lwrist', # 36
365
- 'neck', # 37
366
- 'headtop', # 38
367
- 'hip', # 39 'Pelvis (MPII)', # 39
368
- 'thorax', # 40 'Thorax (MPII)', # 40
369
- 'Spine (H36M)', # 41
370
- 'Jaw (H36M)', # 42
371
- 'Head (H36M)', # 43
372
- 'nose', # 44
373
- 'leye', # 45 'Left Eye', # 45
374
- 'reye', # 46 'Right Eye', # 46
375
- 'lear', # 47 'Left Ear', # 47
376
- 'rear', # 48 'Right Ear', # 48
377
- ]
378
-
379
- def get_muco3dhp_joint_names():
380
- return [
381
- 'headtop',
382
- 'thorax',
383
- 'rshoulder',
384
- 'relbow',
385
- 'rwrist',
386
- 'lshoulder',
387
- 'lelbow',
388
- 'lwrist',
389
- 'rhip',
390
- 'rknee',
391
- 'rankle',
392
- 'lhip',
393
- 'lknee',
394
- 'lankle',
395
- 'hip',
396
- 'Spine (H36M)',
397
- 'Head (H36M)',
398
- 'R_Hand',
399
- 'L_Hand',
400
- 'R_Toe',
401
- 'L_Toe'
402
- ]
403
-
404
- def get_h36m_joint_names():
405
- return [
406
- 'hip', # 0
407
- 'lhip', # 1
408
- 'lknee', # 2
409
- 'lankle', # 3
410
- 'rhip', # 4
411
- 'rknee', # 5
412
- 'rankle', # 6
413
- 'Spine (H36M)', # 7
414
- 'neck', # 8
415
- 'Head (H36M)', # 9
416
- 'headtop', # 10
417
- 'lshoulder', # 11
418
- 'lelbow', # 12
419
- 'lwrist', # 13
420
- 'rshoulder', # 14
421
- 'relbow', # 15
422
- 'rwrist', # 16
423
- ]
424
-
425
-
426
- def get_spin_skeleton():
427
- return np.array(
428
- [
429
- [0 , 1],
430
- [1 , 2],
431
- [2 , 3],
432
- [3 , 4],
433
- [1 , 5],
434
- [5 , 6],
435
- [6 , 7],
436
- [1 , 8],
437
- [8 , 9],
438
- [9 ,10],
439
- [10,11],
440
- [8 ,12],
441
- [12,13],
442
- [13,14],
443
- [0 ,15],
444
- [0 ,16],
445
- [15,17],
446
- [16,18],
447
- [21,19],
448
- [19,20],
449
- [14,21],
450
- [11,24],
451
- [24,22],
452
- [22,23],
453
- [0 ,38],
454
- ]
455
- )
456
-
457
-
458
- def get_openpose_skeleton():
459
- return np.array(
460
- [
461
- [0 , 1],
462
- [1 , 2],
463
- [2 , 3],
464
- [3 , 4],
465
- [1 , 5],
466
- [5 , 6],
467
- [6 , 7],
468
- [1 , 8],
469
- [8 , 9],
470
- [9 ,10],
471
- [10,11],
472
- [8 ,12],
473
- [12,13],
474
- [13,14],
475
- [0 ,15],
476
- [0 ,16],
477
- [15,17],
478
- [16,18],
479
- [21,19],
480
- [19,20],
481
- [14,21],
482
- [11,24],
483
- [24,22],
484
- [22,23],
485
- ]
486
- )
487
-
488
-
489
- def get_posetrack_joint_names():
490
- return [
491
- "nose",
492
- "neck",
493
- "headtop",
494
- "lear",
495
- "rear",
496
- "lshoulder",
497
- "rshoulder",
498
- "lelbow",
499
- "relbow",
500
- "lwrist",
501
- "rwrist",
502
- "lhip",
503
- "rhip",
504
- "lknee",
505
- "rknee",
506
- "lankle",
507
- "rankle"
508
- ]
509
-
510
-
511
- def get_posetrack_original_kp_names():
512
- return [
513
- 'nose',
514
- 'head_bottom',
515
- 'head_top',
516
- 'left_ear',
517
- 'right_ear',
518
- 'left_shoulder',
519
- 'right_shoulder',
520
- 'left_elbow',
521
- 'right_elbow',
522
- 'left_wrist',
523
- 'right_wrist',
524
- 'left_hip',
525
- 'right_hip',
526
- 'left_knee',
527
- 'right_knee',
528
- 'left_ankle',
529
- 'right_ankle'
530
- ]
531
-
532
-
533
- def get_pennaction_joint_names():
534
- return [
535
- "headtop", # 0
536
- "lshoulder", # 1
537
- "rshoulder", # 2
538
- "lelbow", # 3
539
- "relbow", # 4
540
- "lwrist", # 5
541
- "rwrist", # 6
542
- "lhip" , # 7
543
- "rhip" , # 8
544
- "lknee", # 9
545
- "rknee" , # 10
546
- "lankle", # 11
547
- "rankle" # 12
548
- ]
549
-
550
-
551
- def get_common_joint_names():
552
- return [
553
- "rankle", # 0 "lankle", # 0
554
- "rknee", # 1 "lknee", # 1
555
- "rhip", # 2 "lhip", # 2
556
- "lhip", # 3 "rhip", # 3
557
- "lknee", # 4 "rknee", # 4
558
- "lankle", # 5 "rankle", # 5
559
- "rwrist", # 6 "lwrist", # 6
560
- "relbow", # 7 "lelbow", # 7
561
- "rshoulder", # 8 "lshoulder", # 8
562
- "lshoulder", # 9 "rshoulder", # 9
563
- "lelbow", # 10 "relbow", # 10
564
- "lwrist", # 11 "rwrist", # 11
565
- "neck", # 12 "neck", # 12
566
- "headtop", # 13 "headtop", # 13
567
- ]
568
-
569
-
570
- def get_common_paper_joint_names():
571
- return [
572
- "Right Ankle", # 0 "lankle", # 0
573
- "Right Knee", # 1 "lknee", # 1
574
- "Right Hip", # 2 "lhip", # 2
575
- "Left Hip", # 3 "rhip", # 3
576
- "Left Knee", # 4 "rknee", # 4
577
- "Left Ankle", # 5 "rankle", # 5
578
- "Right Wrist", # 6 "lwrist", # 6
579
- "Right Elbow", # 7 "lelbow", # 7
580
- "Right Shoulder", # 8 "lshoulder", # 8
581
- "Left Shoulder", # 9 "rshoulder", # 9
582
- "Left Elbow", # 10 "relbow", # 10
583
- "Left Wrist", # 11 "rwrist", # 11
584
- "Neck", # 12 "neck", # 12
585
- "Head", # 13 "headtop", # 13
586
- ]
587
-
588
-
589
- def get_common_skeleton():
590
- return np.array(
591
- [
592
- [ 0, 1 ],
593
- [ 1, 2 ],
594
- [ 3, 4 ],
595
- [ 4, 5 ],
596
- [ 6, 7 ],
597
- [ 7, 8 ],
598
- [ 8, 2 ],
599
- [ 8, 9 ],
600
- [ 9, 3 ],
601
- [ 2, 3 ],
602
- [ 8, 12],
603
- [ 9, 10],
604
- [12, 9 ],
605
- [10, 11],
606
- [12, 13],
607
- ]
608
- )
609
-
610
-
611
- def get_coco_joint_names():
612
- return [
613
- "nose", # 0
614
- "leye", # 1
615
- "reye", # 2
616
- "lear", # 3
617
- "rear", # 4
618
- "lshoulder", # 5
619
- "rshoulder", # 6
620
- "lelbow", # 7
621
- "relbow", # 8
622
- "lwrist", # 9
623
- "rwrist", # 10
624
- "lhip", # 11
625
- "rhip", # 12
626
- "lknee", # 13
627
- "rknee", # 14
628
- "lankle", # 15
629
- "rankle", # 16
630
- ]
631
-
632
-
633
- def get_ochuman_joint_names():
634
- return [
635
- 'rshoulder',
636
- 'relbow',
637
- 'rwrist',
638
- 'lshoulder',
639
- 'lelbow',
640
- 'lwrist',
641
- 'rhip',
642
- 'rknee',
643
- 'rankle',
644
- 'lhip',
645
- 'lknee',
646
- 'lankle',
647
- 'headtop',
648
- 'neck',
649
- 'rear',
650
- 'lear',
651
- 'nose',
652
- 'reye',
653
- 'leye'
654
- ]
655
-
656
-
657
- def get_crowdpose_joint_names():
658
- return [
659
- 'lshoulder',
660
- 'rshoulder',
661
- 'lelbow',
662
- 'relbow',
663
- 'lwrist',
664
- 'rwrist',
665
- 'lhip',
666
- 'rhip',
667
- 'lknee',
668
- 'rknee',
669
- 'lankle',
670
- 'rankle',
671
- 'headtop',
672
- 'neck'
673
- ]
674
-
675
- def get_coco_skeleton():
676
- # 0 - nose,
677
- # 1 - leye,
678
- # 2 - reye,
679
- # 3 - lear,
680
- # 4 - rear,
681
- # 5 - lshoulder,
682
- # 6 - rshoulder,
683
- # 7 - lelbow,
684
- # 8 - relbow,
685
- # 9 - lwrist,
686
- # 10 - rwrist,
687
- # 11 - lhip,
688
- # 12 - rhip,
689
- # 13 - lknee,
690
- # 14 - rknee,
691
- # 15 - lankle,
692
- # 16 - rankle,
693
- return np.array(
694
- [
695
- [15, 13],
696
- [13, 11],
697
- [16, 14],
698
- [14, 12],
699
- [11, 12],
700
- [ 5, 11],
701
- [ 6, 12],
702
- [ 5, 6 ],
703
- [ 5, 7 ],
704
- [ 6, 8 ],
705
- [ 7, 9 ],
706
- [ 8, 10],
707
- [ 1, 2 ],
708
- [ 0, 1 ],
709
- [ 0, 2 ],
710
- [ 1, 3 ],
711
- [ 2, 4 ],
712
- [ 3, 5 ],
713
- [ 4, 6 ]
714
- ]
715
- )
716
-
717
-
718
- def get_mpii_joint_names():
719
- return [
720
- "rankle", # 0
721
- "rknee", # 1
722
- "rhip", # 2
723
- "lhip", # 3
724
- "lknee", # 4
725
- "lankle", # 5
726
- "hip", # 6
727
- "thorax", # 7
728
- "neck", # 8
729
- "headtop", # 9
730
- "rwrist", # 10
731
- "relbow", # 11
732
- "rshoulder", # 12
733
- "lshoulder", # 13
734
- "lelbow", # 14
735
- "lwrist", # 15
736
- ]
737
-
738
-
739
- def get_mpii_skeleton():
740
- # 0 - rankle,
741
- # 1 - rknee,
742
- # 2 - rhip,
743
- # 3 - lhip,
744
- # 4 - lknee,
745
- # 5 - lankle,
746
- # 6 - hip,
747
- # 7 - thorax,
748
- # 8 - neck,
749
- # 9 - headtop,
750
- # 10 - rwrist,
751
- # 11 - relbow,
752
- # 12 - rshoulder,
753
- # 13 - lshoulder,
754
- # 14 - lelbow,
755
- # 15 - lwrist,
756
- return np.array(
757
- [
758
- [ 0, 1 ],
759
- [ 1, 2 ],
760
- [ 2, 6 ],
761
- [ 6, 3 ],
762
- [ 3, 4 ],
763
- [ 4, 5 ],
764
- [ 6, 7 ],
765
- [ 7, 8 ],
766
- [ 8, 9 ],
767
- [ 7, 12],
768
- [12, 11],
769
- [11, 10],
770
- [ 7, 13],
771
- [13, 14],
772
- [14, 15]
773
- ]
774
- )
775
-
776
-
777
- def get_aich_joint_names():
778
- return [
779
- "rshoulder", # 0
780
- "relbow", # 1
781
- "rwrist", # 2
782
- "lshoulder", # 3
783
- "lelbow", # 4
784
- "lwrist", # 5
785
- "rhip", # 6
786
- "rknee", # 7
787
- "rankle", # 8
788
- "lhip", # 9
789
- "lknee", # 10
790
- "lankle", # 11
791
- "headtop", # 12
792
- "neck", # 13
793
- ]
794
-
795
-
796
- def get_aich_skeleton():
797
- # 0 - rshoulder,
798
- # 1 - relbow,
799
- # 2 - rwrist,
800
- # 3 - lshoulder,
801
- # 4 - lelbow,
802
- # 5 - lwrist,
803
- # 6 - rhip,
804
- # 7 - rknee,
805
- # 8 - rankle,
806
- # 9 - lhip,
807
- # 10 - lknee,
808
- # 11 - lankle,
809
- # 12 - headtop,
810
- # 13 - neck,
811
- return np.array(
812
- [
813
- [ 0, 1 ],
814
- [ 1, 2 ],
815
- [ 3, 4 ],
816
- [ 4, 5 ],
817
- [ 6, 7 ],
818
- [ 7, 8 ],
819
- [ 9, 10],
820
- [10, 11],
821
- [12, 13],
822
- [13, 0 ],
823
- [13, 3 ],
824
- [ 0, 6 ],
825
- [ 3, 9 ]
826
- ]
827
- )
828
-
829
-
830
- def get_3dpw_joint_names():
831
- return [
832
- "nose", # 0
833
- "thorax", # 1
834
- "rshoulder", # 2
835
- "relbow", # 3
836
- "rwrist", # 4
837
- "lshoulder", # 5
838
- "lelbow", # 6
839
- "lwrist", # 7
840
- "rhip", # 8
841
- "rknee", # 9
842
- "rankle", # 10
843
- "lhip", # 11
844
- "lknee", # 12
845
- "lankle", # 13
846
- ]
847
-
848
-
849
- def get_3dpw_skeleton():
850
- return np.array(
851
- [
852
- [ 0, 1 ],
853
- [ 1, 2 ],
854
- [ 2, 3 ],
855
- [ 3, 4 ],
856
- [ 1, 5 ],
857
- [ 5, 6 ],
858
- [ 6, 7 ],
859
- [ 2, 8 ],
860
- [ 5, 11],
861
- [ 8, 11],
862
- [ 8, 9 ],
863
- [ 9, 10],
864
- [11, 12],
865
- [12, 13]
866
- ]
867
- )
868
-
869
-
870
- def get_smplcoco_joint_names():
871
- return [
872
- "rankle", # 0
873
- "rknee", # 1
874
- "rhip", # 2
875
- "lhip", # 3
876
- "lknee", # 4
877
- "lankle", # 5
878
- "rwrist", # 6
879
- "relbow", # 7
880
- "rshoulder", # 8
881
- "lshoulder", # 9
882
- "lelbow", # 10
883
- "lwrist", # 11
884
- "neck", # 12
885
- "headtop", # 13
886
- "nose", # 14
887
- "leye", # 15
888
- "reye", # 16
889
- "lear", # 17
890
- "rear", # 18
891
- ]
892
-
893
-
894
- def get_smplcoco_skeleton():
895
- return np.array(
896
- [
897
- [ 0, 1 ],
898
- [ 1, 2 ],
899
- [ 3, 4 ],
900
- [ 4, 5 ],
901
- [ 6, 7 ],
902
- [ 7, 8 ],
903
- [ 8, 12],
904
- [12, 9 ],
905
- [ 9, 10],
906
- [10, 11],
907
- [12, 13],
908
- [14, 15],
909
- [15, 17],
910
- [16, 18],
911
- [14, 16],
912
- [ 8, 2 ],
913
- [ 9, 3 ],
914
- [ 2, 3 ],
915
- ]
916
- )
917
-
918
-
919
- def get_smpl_joint_names():
920
- return [
921
- 'hips', # 0
922
- 'leftUpLeg', # 1
923
- 'rightUpLeg', # 2
924
- 'spine', # 3
925
- 'leftLeg', # 4
926
- 'rightLeg', # 5
927
- 'spine1', # 6
928
- 'leftFoot', # 7
929
- 'rightFoot', # 8
930
- 'spine2', # 9
931
- 'leftToeBase', # 10
932
- 'rightToeBase', # 11
933
- 'neck', # 12
934
- 'leftShoulder', # 13
935
- 'rightShoulder', # 14
936
- 'head', # 15
937
- 'leftArm', # 16
938
- 'rightArm', # 17
939
- 'leftForeArm', # 18
940
- 'rightForeArm', # 19
941
- 'leftHand', # 20
942
- 'rightHand', # 21
943
- 'leftHandIndex1', # 22
944
- 'rightHandIndex1', # 23
945
- ]
946
-
947
-
948
- def get_smpl_paper_joint_names():
949
- return [
950
- 'Hips', # 0
951
- 'Left Hip', # 1
952
- 'Right Hip', # 2
953
- 'Spine', # 3
954
- 'Left Knee', # 4
955
- 'Right Knee', # 5
956
- 'Spine_1', # 6
957
- 'Left Ankle', # 7
958
- 'Right Ankle', # 8
959
- 'Spine_2', # 9
960
- 'Left Toe', # 10
961
- 'Right Toe', # 11
962
- 'Neck', # 12
963
- 'Left Shoulder', # 13
964
- 'Right Shoulder', # 14
965
- 'Head', # 15
966
- 'Left Arm', # 16
967
- 'Right Arm', # 17
968
- 'Left Elbow', # 18
969
- 'Right Elbow', # 19
970
- 'Left Hand', # 20
971
- 'Right Hand', # 21
972
- 'Left Thumb', # 22
973
- 'Right Thumb', # 23
974
- ]
975
-
976
-
977
- def get_smpl_neighbor_triplets():
978
- return [
979
- [ 0, 1, 2 ], # 0
980
- [ 1, 4, 0 ], # 1
981
- [ 2, 0, 5 ], # 2
982
- [ 3, 0, 6 ], # 3
983
- [ 4, 7, 1 ], # 4
984
- [ 5, 2, 8 ], # 5
985
- [ 6, 3, 9 ], # 6
986
- [ 7, 10, 4 ], # 7
987
- [ 8, 5, 11], # 8
988
- [ 9, 13, 14], # 9
989
- [10, 7, 4 ], # 10
990
- [11, 8, 5 ], # 11
991
- [12, 9, 15], # 12
992
- [13, 16, 9 ], # 13
993
- [14, 9, 17], # 14
994
- [15, 9, 12], # 15
995
- [16, 18, 13], # 16
996
- [17, 14, 19], # 17
997
- [18, 20, 16], # 18
998
- [19, 17, 21], # 19
999
- [20, 22, 18], # 20
1000
- [21, 19, 23], # 21
1001
- [22, 20, 18], # 22
1002
- [23, 19, 21], # 23
1003
- ]
1004
-
1005
-
1006
- def get_smpl_skeleton():
1007
- return np.array(
1008
- [
1009
- [ 0, 1 ],
1010
- [ 0, 2 ],
1011
- [ 0, 3 ],
1012
- [ 1, 4 ],
1013
- [ 2, 5 ],
1014
- [ 3, 6 ],
1015
- [ 4, 7 ],
1016
- [ 5, 8 ],
1017
- [ 6, 9 ],
1018
- [ 7, 10],
1019
- [ 8, 11],
1020
- [ 9, 12],
1021
- [ 9, 13],
1022
- [ 9, 14],
1023
- [12, 15],
1024
- [13, 16],
1025
- [14, 17],
1026
- [16, 18],
1027
- [17, 19],
1028
- [18, 20],
1029
- [19, 21],
1030
- [20, 22],
1031
- [21, 23],
1032
- ]
1033
- )
1034
-
1035
-
1036
- def map_spin_joints_to_smpl():
1037
- # this function will primarily be used to copy 2D keypoint
1038
- # confidences to pose parameters
1039
- return [
1040
- [(39, 27, 28), 0], # hip,lhip,rhip->hips
1041
- [(28,), 1], # lhip->leftUpLeg
1042
- [(27,), 2], # rhip->rightUpLeg
1043
- [(41, 27, 28, 39), 3], # Spine->spine
1044
- [(29,), 4], # lknee->leftLeg
1045
- [(26,), 5], # rknee->rightLeg
1046
- [(41, 40, 33, 34,), 6], # spine, thorax ->spine1
1047
- [(30,), 7], # lankle->leftFoot
1048
- [(25,), 8], # rankle->rightFoot
1049
- [(40, 33, 34), 9], # thorax,shoulders->spine2
1050
- [(30,), 10], # lankle -> leftToe
1051
- [(25,), 11], # rankle -> rightToe
1052
- [(37, 42, 33, 34), 12], # neck, shoulders -> neck
1053
- [(34,), 13], # lshoulder->leftShoulder
1054
- [(33,), 14], # rshoulder->rightShoulder
1055
- [(33, 34, 38, 43, 44, 45, 46, 47, 48,), 15], # nose, eyes, ears, headtop, shoulders->head
1056
- [(34,), 16], # lshoulder->leftArm
1057
- [(33,), 17], # rshoulder->rightArm
1058
- [(35,), 18], # lelbow->leftForeArm
1059
- [(32,), 19], # relbow->rightForeArm
1060
- [(36,), 20], # lwrist->leftHand
1061
- [(31,), 21], # rwrist->rightHand
1062
- [(36,), 22], # lhand -> leftHandIndex
1063
- [(31,), 23], # rhand -> rightHandIndex
1064
- ]
1065
-
1066
-
1067
- def map_smpl_to_common():
1068
- return [
1069
- [(11, 8), 0], # rightToe, rightFoot -> rankle
1070
- [(5,), 1], # rightLeg -> rknee
1071
- [(2,), 2], # rhip
1072
- [(1,), 3], # lhip
1073
- [(4,), 4], # leftLeg -> lknee
1074
- [(10, 7), 5], # leftToe, leftFoot -> lankle
1075
- [(21, 23), 6], # rwrist
1076
- [(18,), 7], # relbow
1077
- [(17, 14), 8], # rshoulder
1078
- [(16, 13), 9], # lshoulder
1079
- [(19,), 10], # lelbow
1080
- [(20, 22), 11], # lwrist
1081
- [(0, 3, 6, 9, 12), 12], # neck
1082
- [(15,), 13], # headtop
1083
- ]
1084
-
1085
-
1086
- def relation_among_spin_joints():
1087
- # this function will primarily be used to copy 2D keypoint
1088
- # confidences to 3D joints
1089
- return [
1090
- [(), 25],
1091
- [(), 26],
1092
- [(39,), 27],
1093
- [(39,), 28],
1094
- [(), 29],
1095
- [(), 30],
1096
- [(), 31],
1097
- [(), 32],
1098
- [(), 33],
1099
- [(), 34],
1100
- [(), 35],
1101
- [(), 36],
1102
- [(40,42,44,43,38,33,34,), 37],
1103
- [(43,44,45,46,47,48,33,34,), 38],
1104
- [(27,28,), 39],
1105
- [(27,28,37,41,42,), 40],
1106
- [(27,28,39,40,), 41],
1107
- [(37,38,44,45,46,47,48,), 42],
1108
- [(44,45,46,47,48,38,42,37,33,34,), 43],
1109
- [(44,45,46,47,48,38,42,37,33,34), 44],
1110
- [(44,45,46,47,48,38,42,37,33,34), 45],
1111
- [(44,45,46,47,48,38,42,37,33,34), 46],
1112
- [(44,45,46,47,48,38,42,37,33,34), 47],
1113
- [(44,45,46,47,48,38,42,37,33,34), 48],
1114
- ]
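A minimal sketch of how the mapping tables above can be consumed, assuming the functions are in scope (for example, pasted into the same module). Treating each common joint as the mean of its mapped SMPL joints is an illustrative choice, and the random joint values are placeholders:

import numpy as np

# toy SMPL joints: 24 joints x (x, y, z); real values would come from the SMPL model
smpl_joints = np.random.rand(24, 3)
common_joints = np.zeros((14, 3))

for src_ids, dst_id in map_smpl_to_common():
    # pool every SMPL joint mapped onto this common joint (mean is one plausible choice)
    common_joints[dst_id] = smpl_joints[list(src_ids)].mean(axis=0)

print(common_joints.shape)            # (14, 3)
print(get_smpl_skeleton().shape)      # (23, 2) parent/child bone pairs for drawing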
utils/loss.py DELETED
@@ -1,207 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from common import constants
4
- from models.smpl import SMPL
5
- from smplx import SMPLX
6
- import pickle as pkl
7
- import numpy as np
8
- from utils.mesh_utils import save_results_mesh
9
- from utils.diff_renderer import Pytorch3D
10
- import os
11
- import cv2
12
-
13
-
14
- class sem_loss_function(nn.Module):
15
- def __init__(self):
16
- super(sem_loss_function, self).__init__()
17
- self.ce = nn.BCELoss()
18
-
19
- def forward(self, y_true, y_pred):
20
- loss = self.ce(y_pred, y_true)
21
- return loss
22
-
23
-
24
- class class_loss_function(nn.Module):
25
- def __init__(self):
26
- super(class_loss_function, self).__init__()
27
- self.ce_loss = nn.BCELoss()
28
- # self.ce_loss = nn.MultiLabelSoftMarginLoss()
29
- # self.ce_loss = nn.MultiLabelMarginLoss()
30
-
31
- def forward(self, y_true, y_pred, valid_mask):
32
- # y_true = torch.squeeze(y_true, 1).long()
33
- # y_true = torch.squeeze(y_true, 1)
34
- # y_pred = torch.squeeze(y_pred, 1)
35
- bs = y_true.shape[0]
36
- if bs != 1:
37
- y_pred = y_pred[valid_mask == 1]
38
- y_true = y_true[valid_mask == 1]
39
- if len(y_pred) > 0:
40
- return self.ce_loss(y_pred, y_true)
41
- else:
42
- return torch.tensor(0.0).to(y_pred.device)
43
-
44
-
45
- class pixel_anchoring_function(nn.Module):
46
- def __init__(self, model_type, device='cuda'):
47
- super(pixel_anchoring_function, self).__init__()
48
-
49
- self.device = device
50
-
51
- self.model_type = model_type
52
-
53
- if self.model_type == 'smplx':
54
- # load mapping from smpl vertices to smplx vertices
55
- mapping_pkl = os.path.join(constants.CONTACT_MAPPING_PATH, "smpl_to_smplx.pkl")
56
- with open(mapping_pkl, 'rb') as f:
57
- smpl_to_smplx_mapping = pkl.load(f)
58
- smpl_to_smplx_mapping = smpl_to_smplx_mapping["matrix"]
59
- self.smpl_to_smplx_mapping = torch.from_numpy(smpl_to_smplx_mapping).float().to(self.device)
60
-
61
-
62
- # Setup the SMPL model
63
- if self.model_type == 'smpl':
64
- self.n_vertices = 6890
65
- self.body_model = SMPL(constants.SMPL_MODEL_DIR).to(self.device)
66
- if self.model_type == 'smplx':
67
- self.n_vertices = 10475
68
- self.body_model = SMPLX(constants.SMPLX_MODEL_DIR,
69
- num_betas=10,
70
- use_pca=False).to(self.device)
71
- self.body_faces = torch.LongTensor(self.body_model.faces.astype(np.int32)).to(self.device)
72
-
73
- self.ce_loss = nn.BCELoss()
74
-
75
- def get_posed_mesh(self, body_params, debug=False):
76
- betas = body_params['betas']
77
- pose = body_params['pose']
78
- transl = body_params['transl']
79
-
80
- # extra smplx params
81
- extra_args = {'jaw_pose': torch.zeros((betas.shape[0], 3)).float().to(self.device),
82
- 'leye_pose': torch.zeros((betas.shape[0], 3)).float().to(self.device),
83
- 'reye_pose': torch.zeros((betas.shape[0], 3)).float().to(self.device),
84
- 'expression': torch.zeros((betas.shape[0], 10)).float().to(self.device),
85
- 'left_hand_pose': torch.zeros((betas.shape[0], 45)).float().to(self.device),
86
- 'right_hand_pose': torch.zeros((betas.shape[0], 45)).float().to(self.device)}
87
-
88
- smpl_output = self.body_model(betas=betas,
89
- body_pose=pose[:, 3:],
90
- global_orient=pose[:, :3],
91
- pose2rot=True,
92
- transl=transl,
93
- **extra_args)
94
- smpl_verts = smpl_output.vertices
95
- smpl_joints = smpl_output.joints
96
-
97
- if debug:
98
- for mesh_i in range(smpl_verts.shape[0]):
99
- out_dir = 'temp_meshes'
100
- os.makedirs(out_dir, exist_ok=True)
101
- out_file = os.path.join(out_dir, f'temp_mesh_{mesh_i:04d}.obj')
102
- save_results_mesh(smpl_verts[mesh_i], self.body_model.faces, out_file)
103
- return smpl_verts, smpl_joints
104
-
105
-
106
- def render_batch(self, smpl_verts, cam_k, img_scale_factor, vertex_colors=None, face_textures=None, debug=False):
107
-
108
- bs = smpl_verts.shape[0]
109
-
110
- # Incorporate resizing factor into the camera
111
- img_w = 256 # TODO: Remove hardcoding
112
- img_h = 256 # TODO: Remove hardcoding
113
- focal_length_x = cam_k[:, 0, 0] * img_scale_factor[:, 0]
114
- focal_length_y = cam_k[:, 1, 1] * img_scale_factor[:, 1]
115
- # convert to float for pytorch3d
116
- focal_length_x, focal_length_y = focal_length_x.float(), focal_length_y.float()
117
-
118
- # concatenate focal length
119
- focal_length = torch.stack([focal_length_x, focal_length_y], dim=1)
120
-
121
- # Setup renderer
122
- renderer = Pytorch3D(img_h=img_h,
123
- img_w=img_w,
124
- focal_length=focal_length,
125
- smpl_faces=self.body_faces,
126
- texture_mode='deco',
127
- vertex_colors=vertex_colors,
128
- face_textures=face_textures,
129
- is_train=True,
130
- is_cam_batch=True)
131
- front_view = renderer(smpl_verts)
132
- if debug:
133
- # visualize the front view as images in a temp_images folder
134
- for i in range(bs):
135
- front_view_rgb = front_view[i, :3, :, :].permute(1, 2, 0).detach().cpu()
136
- front_view_mask = front_view[i, 3, :, :].detach().cpu()
137
- out_dir = 'temp_images'
138
- os.makedirs(out_dir, exist_ok=True)
139
- out_file_rgb = os.path.join(out_dir, f'{i:04d}_rgb.png')
140
- out_file_mask = os.path.join(out_dir, f'{i:04d}_mask.png')
141
- cv2.imwrite(out_file_rgb, front_view_rgb.numpy()*255)
142
- cv2.imwrite(out_file_mask, front_view_mask.numpy()*255)
143
-
144
- return front_view
145
-
146
- def paint_contact(self, pred_contact):
147
- """
148
- Paints the contact vertices on the SMPL mesh
149
-
150
- Args:
151
- pred_contact: probabilities of contact vertices
152
-
153
- Returns:
154
- pred_rgb: RGB colors for the contact vertices
155
- """
156
- bs = pred_contact.shape[0]
157
-
158
- # initialize black and white colors
159
- colors = torch.tensor([[0, 0, 0], [1, 1, 1]]).float().to(self.device)
160
- colors = torch.unsqueeze(colors, 0).expand(bs, -1, -1)
161
-
162
- # add another dimension to the contact probabilities for inverse probabilities
163
- pred_contact = torch.unsqueeze(pred_contact, 2)
164
- pred_contact = torch.cat((1 - pred_contact, pred_contact), 2)
165
-
166
- # get pred_rgb colors
167
- pred_vert_rgb = torch.bmm(pred_contact, colors)
168
- pred_face_rgb = pred_vert_rgb[:, self.body_faces, :][:, :, 0, :] # take the first vertex color
169
- pred_face_texture = torch.zeros((bs, self.body_faces.shape[0], 1, 1, 3), dtype=torch.float32).to(self.device)
170
- pred_face_texture[:, :, 0, 0, :] = pred_face_rgb
171
- return pred_vert_rgb, pred_face_texture
172
-
173
- def forward(self, pred_contact, body_params, cam_k, img_scale_factor, gt_contact_polygon, valid_mask):
174
- """
175
- Takes predicted contact labels (probabilities), transfers them to the posed mesh and
176
- renders to the image. Loss is computed between the rendered contact and the ground truth
177
- polygons from HOT.
178
-
179
- Args:
180
- pred_contact: predicted contact labels (probabilities)
181
- body_params: SMPL parameters in camera coords
182
- cam_k: camera intrinsics
183
- gt_contact_polygon: ground truth polygons from HOT
184
- """
185
- # convert pred_contact to smplx
186
- bs = pred_contact.shape[0]
187
- if self.model_type == 'smplx':
188
- smpl_to_smplx_mapping = self.smpl_to_smplx_mapping[None].expand(bs, -1, -1)
189
- pred_contact = torch.bmm(smpl_to_smplx_mapping, pred_contact[..., None])
190
- pred_contact = pred_contact.squeeze()
191
-
192
- # get the posed mesh
193
- smpl_verts, smpl_joints = self.get_posed_mesh(body_params)
194
-
195
- # paint the contact vertices on the mesh
196
- vertex_colors, face_textures = self.paint_contact(pred_contact)
197
-
198
- # render the mesh
199
- front_view = self.render_batch(smpl_verts, cam_k, img_scale_factor, vertex_colors, face_textures)
200
- front_view_rgb = front_view[:, :3, :, :].permute(0, 2, 3, 1)
201
- front_view_mask = front_view[:, 3, :, :]
202
-
203
- # compute segmentation loss between rendered contact mask and ground truth contact mask
204
- front_view_rgb = front_view_rgb[valid_mask == 1]
205
- gt_contact_polygon = gt_contact_polygon[valid_mask == 1]
206
- loss = self.ce_loss(front_view_rgb, gt_contact_polygon)
207
- return loss, front_view_rgb, front_view_mask
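A standalone sketch of the vertex-painting step performed by paint_contact above, on toy shapes (the batch size and the 6890-vertex count are assumptions for illustration):

import torch

bs, n_verts = 2, 6890
pred_contact = torch.rand(bs, n_verts)                # per-vertex contact probabilities

colors = torch.tensor([[0., 0., 0.], [1., 1., 1.]])   # black = no contact, white = contact
colors = colors.unsqueeze(0).expand(bs, -1, -1)       # (bs, 2, 3)

p = pred_contact.unsqueeze(2)                         # (bs, n_verts, 1)
weights = torch.cat((1 - p, p), dim=2)                # (bs, n_verts, 2): (1 - p, p) per vertex

vertex_rgb = torch.bmm(weights, colors)               # (bs, n_verts, 3), values in [0, 1]
print(vertex_rgb.shape)

The batched matmul blends black and white per vertex, so the contact probability becomes a grey-level intensity that the differentiable renderer can compare against the ground-truth contact polygons.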
utils/mesh_utils.py DELETED
@@ -1,6 +0,0 @@
1
- import trimesh
2
-
3
- def save_results_mesh(vertices, faces, filename):
4
- mesh = trimesh.Trimesh(vertices, faces, process=False)
5
- mesh.export(filename)
6
- print(f'saved results to {filename}')
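A quick sketch of the helper above, using a trimesh primitive as a stand-in for a real body mesh (the output filename is arbitrary):

import trimesh

box = trimesh.creation.box(extents=(1.0, 1.0, 1.0))
save_results_mesh(box.vertices, box.faces, 'box.obj')   # exports box.obj via trimesh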
utils/metrics.py DELETED
@@ -1,106 +0,0 @@
1
- import numpy as np
2
- import torch
3
- import monai.metrics as metrics
4
- from common.constants import DIST_MATRIX_PATH
5
-
6
- DIST_MATRIX = np.load(DIST_MATRIX_PATH)
7
-
8
- def metric(mask, pred, back=True):
9
- iou = metrics.compute_meaniou(pred, mask, back, False)
10
- iou = iou.mean()
11
-
12
- return iou
13
-
14
- def precision_recall_f1score(gt, pred):
15
- """
16
- Compute precision, recall, and f1
17
- """
18
-
19
- # gt = gt.numpy()
20
- # pred = pred.numpy()
21
-
22
- precision = torch.zeros(gt.shape[0])
23
- recall = torch.zeros(gt.shape[0])
24
- f1 = torch.zeros(gt.shape[0])
25
-
26
- for b in range(gt.shape[0]):
27
- tp_num = gt[b, pred[b, :] >= 0.5].sum()
28
- precision_denominator = (pred[b, :] >= 0.5).sum()
29
- recall_denominator = (gt[b, :]).sum()
30
-
31
- precision_ = tp_num / precision_denominator
32
- recall_ = tp_num / recall_denominator
33
- if precision_denominator == 0: # if no pred
34
- precision_ = 1.
35
- recall_ = 0.
36
- f1_ = 0.
37
- elif recall_denominator == 0: # if no GT
38
- precision_ = 0.
39
- recall_ = 1.
40
- f1_ = 0.
41
- elif (precision_ + recall_) <= 1e-10: # to avoid precision issues
42
- precision_ = 0.
43
- recall_ = 0.
44
- f1_ = 0.
45
- else:
46
- f1_ = 2 * precision_ * recall_ / (precision_ + recall_)
47
-
48
- precision[b] = precision_
49
- recall[b] = recall_
50
- f1[b] = f1_
51
-
52
- # return precision, recall, f1
53
- return precision, recall, f1
54
-
55
- def acc_precision_recall_f1score(gt, pred):
56
- """
57
- Compute acc, precision, recall, and f1
58
- """
59
-
60
- # gt = gt.numpy()
61
- # pred = pred.numpy()
62
-
63
- acc = torch.zeros(gt.shape[0])
64
- precision = torch.zeros(gt.shape[0])
65
- recall = torch.zeros(gt.shape[0])
66
- f1 = torch.zeros(gt.shape[0])
67
-
68
- for b in range(gt.shape[0]):
69
- tp_num = gt[b, pred[b, :] >= 0.5].sum()
70
- precision_denominator = (pred[b, :] >= 0.5).sum()
71
- recall_denominator = (gt[b, :]).sum()
72
- tn_num = gt.shape[-1] - precision_denominator - recall_denominator + tp_num
73
-
74
- acc_ = (tp_num + tn_num) / gt.shape[-1]
75
- precision_ = tp_num / (precision_denominator + 1e-10)
76
- recall_ = tp_num / (recall_denominator + 1e-10)
77
- f1_ = 2 * precision_ * recall_ / (precision_ + recall_ + 1e-10)
78
-
79
- acc[b] = acc_
80
- precision[b] = precision_
81
- recall[b] = recall_
- f1[b] = f1_
82
-
83
- # return precision, recall, f1
84
- return acc, precision, recall, f1
85
-
86
- def det_error_metric(pred, gt):
87
-
88
- gt = gt.detach().cpu()
89
- pred = pred.detach().cpu()
90
-
91
- dist_matrix = torch.tensor(DIST_MATRIX)
92
-
93
- false_positive_dist = torch.zeros(gt.shape[0])
94
- false_negative_dist = torch.zeros(gt.shape[0])
95
-
96
- for b in range(gt.shape[0]):
97
- gt_columns = dist_matrix[:, gt[b, :]==1] if any(gt[b, :]==1) else dist_matrix
98
- error_matrix = gt_columns[pred[b, :] >= 0.5, :] if any(pred[b, :] >= 0.5) else gt_columns
99
-
100
- false_positive_dist_ = error_matrix.min(dim=1)[0].mean()
101
- false_negative_dist_ = error_matrix.min(dim=0)[0].mean()
102
-
103
- false_positive_dist[b] = false_positive_dist_
104
- false_negative_dist[b] = false_negative_dist_
105
-
106
- return false_positive_dist, false_negative_dist
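A toy sketch of the per-vertex contact metrics above (assuming precision_recall_f1score is in scope); with the 0.5 threshold used by the function, the expected output is precision 1.0, recall 2/3 and F1 0.8:

import torch

gt = torch.tensor([[1., 0., 1., 0., 1.]])            # (batch, n_vertices) binary labels
pred = torch.tensor([[0.9, 0.2, 0.4, 0.1, 0.8]])     # predicted contact probabilities

precision, recall, f1 = precision_recall_f1score(gt, pred)
print(precision.item(), recall.item(), f1.item())    # 1.0, ~0.667, 0.8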
utils/smpl_uv.py DELETED
@@ -1,167 +0,0 @@
1
- import torch
2
- import trimesh
3
- import numpy as np
4
- import skimage.io as io
5
- from PIL import Image
6
- from smplx import SMPL
7
- from matplotlib import cm as mpl_cm, colors as mpl_colors
8
- from trimesh.visual.color import face_to_vertex_color, vertex_to_face_color, to_rgba
9
-
10
- from common import constants
11
- from .colorwheel import make_color_wheel_image
12
-
13
-
14
- def get_smpl_uv():
15
- uv_obj = 'data/body_models/smpl_uv_20200910/smpl_uv.obj'
16
-
17
- uv_map = []
18
- with open(uv_obj) as f:
19
- for line in f.readlines():
20
- if line.startswith('vt'):
21
- coords = [float(x) for x in line.split(' ')[1:]]
22
- uv_map.append(coords)
23
-
24
- uv_map = np.array(uv_map)
25
-
26
- return uv_map
27
-
28
-
29
- def show_uv_texture():
30
- # image = io.imread('data/body_models/smpl_uv_20200910/smpl_uv_20200910.png')
31
- image = make_color_wheel_image(1024, 1024)
32
- image = Image.fromarray(image)
33
-
34
- uv = np.load('data/body_models/smpl_uv_20200910/uv_table.npy') # get_smpl_uv()
35
- material = trimesh.visual.texture.SimpleMaterial(image=image)
36
- tex_visuals = trimesh.visual.TextureVisuals(uv=uv, image=image, material=material)
37
-
38
- smpl = SMPL(constants.SMPL_MODEL_DIR)
39
-
40
- faces = smpl.faces
41
- verts = smpl().vertices[0].detach().numpy()
42
-
43
- # assert(len(uv) == len(verts))
44
- print(uv.shape)
45
- vc = tex_visuals.to_color().vertex_colors
46
- fc = trimesh.visual.color.vertex_to_face_color(vc, faces)
47
- face_colors = fc.copy()
48
- fc = fc.astype(float)
49
- vc = vc.astype(float)
50
- fc[:,:3] = fc[:,:3] / 255.
51
- vc[:,:3] = vc[:,:3] / 255.
52
- print(fc[:,:3].max(), fc[:,:3].min(), fc[:,:3].mean())
53
- print(vc[:, :3].max(), vc[:, :3].min(), vc[:, :3].mean())
54
- np.save('data/body_models/smpl/color_wheel_face_colors.npy', fc)
55
- np.save('data/body_models/smpl/color_wheel_vertex_colors.npy', vc)
56
- print(fc.shape)
57
- mesh = trimesh.Trimesh(verts, faces, validate=True, process=False, face_colors=face_colors)
58
- # mesh = trimesh.load('data/body_models/smpl_uv_20200910/smpl_uv.obj', process=False)
59
- # mesh.visual = tex_visuals
60
-
61
- # import ipdb; ipdb.set_trace()
62
- # print(vc.shape)
63
- mesh.show()
64
-
65
-
66
- def show_colored_mesh():
67
- cm = mpl_cm.get_cmap('jet')
68
- norm_gt = mpl_colors.Normalize()
69
-
70
- smpl = SMPL(constants.SMPL_MODEL_DIR)
71
-
72
- faces = smpl.faces
73
- verts = smpl().vertices[0].detach().numpy()
74
-
75
- m = trimesh.Trimesh(verts, faces, process=False)
76
-
77
- mode = 1
78
- if mode == 0:
79
- # mano_segm_labels = m.triangles_center
80
- face_labels = m.triangles_center
81
- face_colors = (face_labels - face_labels.min()) / np.ptp(face_labels)
82
-
83
- elif mode == 1:
84
- # print(face_labels.shape)
85
- face_labels = m.triangles_center
86
- face_labels = np.argsort(np.linalg.norm(face_labels, axis=-1))
87
- face_colors = np.ones((13776, 4))
88
- face_colors[:, 3] = 1.0
89
- face_colors[:, :3] = cm(norm_gt(face_labels))[:, :3]
90
- elif mode == 2:
91
- # breakpoint()
92
- fc = np.load('data/body_models/smpl_uv_20200910/data/vertex_texture.npy')[0, :, 0, 0, 0, :]
93
- face_colors = np.ones((13776, 4))
94
- face_colors[:, :3] = fc
95
- mesh = trimesh.Trimesh(verts, faces, process=False, face_colors=face_colors)
96
- mesh.show()
97
-
98
-
99
- def get_tenet_texture(mode='smplpix'):
100
- # mode = 'smplpix', 'decomr'
101
-
102
- smpl = SMPL(constants.SMPL_MODEL_DIR)
103
-
104
- faces = smpl.faces
105
- verts = smpl().vertices[0].detach().numpy()
106
-
107
- m = trimesh.Trimesh(verts, faces, process=False)
108
- if mode == 'smplpix':
109
- # mano_segm_labels = m.triangles_center
110
- face_labels = m.triangles_center
111
- face_colors = (face_labels - face_labels.min()) / np.ptp(face_labels)
112
- texture = np.zeros((1, faces.shape[0], 1, 1, 1, 3), dtype=np.float32)
113
- texture[0, :, 0, 0, 0, :] = face_colors[:, :3]
114
- texture = torch.from_numpy(texture).float()
115
- elif mode == 'decomr':
116
- texture = np.load('data/body_models/smpl_uv_20200910/data/vertex_texture.npy')
117
- texture = torch.from_numpy(texture).float()
118
- elif mode == 'colorwheel':
119
- face_colors = np.load('data/body_models/smpl/color_wheel_face_colors.npy')
120
- texture = np.zeros((1, faces.shape[0], 1, 1, 1, 3), dtype=np.float32)
121
- texture[0, :, 0, 0, 0, :] = face_colors[:, :3]
122
- texture = torch.from_numpy(texture).float()
123
- else:
124
- raise ValueError(f'{mode} is not defined!')
125
-
126
- return texture
127
-
128
-
129
- def save_tenet_textures(mode='smplpix'):
130
- # mode = 'smplpix', 'decomr'
131
-
132
- smpl = SMPL(constants.SMPL_MODEL_DIR)
133
-
134
- faces = smpl.faces
135
- verts = smpl().vertices[0].detach().numpy()
136
-
137
- m = trimesh.Trimesh(verts, faces, process=False)
138
-
139
- if mode == 'smplpix':
140
- # mano_segm_labels = m.triangles_center
141
- face_labels = m.triangles_center
142
- face_colors = (face_labels - face_labels.min()) / np.ptp(face_labels)
143
- texture = np.zeros((1, faces.shape[0], 1, 1, 1, 3), dtype=np.float32)
144
- texture[0, :, 0, 0, 0, :] = face_colors[:, :3]
145
- texture = torch.from_numpy(texture).float()
146
-
147
- vert_colors = face_to_vertex_color(m, face_colors).astype(float) / 255.0
148
-
149
- elif mode == 'decomr':
150
- texture = np.load('data/body_models/smpl_uv_20200910/data/vertex_texture.npy')
151
- texture = torch.from_numpy(texture).float()
152
- face_colors = texture[0, :, 0, 0, 0, :]
153
- vert_colors = face_to_vertex_color(m, face_colors).astype(float) / 255.0
154
-
155
- elif mode == 'colorwheel':
156
- face_colors = np.load('data/body_models/smpl/color_wheel_face_colors.npy')
157
- texture = np.zeros((1, faces.shape[0], 1, 1, 1, 3), dtype=np.float32)
158
- texture[0, :, 0, 0, 0, :] = face_colors[:, :3]
159
- texture = torch.from_numpy(texture).float()
160
- face_colors[:, :3] *= 255
161
- vert_colors = face_to_vertex_color(m, face_colors).astype(float) / 255.0
162
- else:
163
- raise ValueError(f'{mode} is not defined!')
164
-
165
- print(vert_colors.shape, vert_colors.max())
166
- np.save(f'data/body_models/smpl/{mode}_vertex_colors.npy', vert_colors)
167
- return texture
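A self-contained sketch of the 'smplpix' texture layout built above, run on a toy icosphere so that no SMPL model files are required:

import numpy as np
import torch
import trimesh

m = trimesh.creation.icosphere(subdivisions=2)
face_labels = m.triangles_center                                      # (n_faces, 3) face centroids
face_colors = (face_labels - face_labels.min()) / np.ptp(face_labels)

texture = np.zeros((1, len(m.faces), 1, 1, 1, 3), dtype=np.float32)   # same layout as above
texture[0, :, 0, 0, 0, :] = face_colors
texture = torch.from_numpy(texture).float()
print(texture.shape)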
vis/__pycache__/visualize.cpython-37.pyc DELETED
Binary file (6.7 kB)
 
vis/visualize.py DELETED
@@ -1,209 +0,0 @@
1
- import cv2
2
- import os
3
- import trimesh
4
- import PIL.Image as pil_img
5
- import numpy as np
6
- import pyrender
7
- from common import constants
8
-
9
- os.environ['PYOPENGL_PLATFORM'] = 'egl'
10
-
11
- def render_image(scene, img_res, img=None, viewer=False):
12
- '''
13
- Render the given pyrender scene and return the image. Can also overlay the mesh on an image.
14
- '''
15
- if viewer:
16
- pyrender.Viewer(scene, use_raymond_lighting=True)
17
- return 0
18
- else:
19
- r = pyrender.OffscreenRenderer(viewport_width=img_res,
20
- viewport_height=img_res,
21
- point_size=1.0)
22
- color, _ = r.render(scene, flags=pyrender.RenderFlags.RGBA)
23
- color = color.astype(np.float32) / 255.0
24
-
25
- if img is not None:
26
- valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis]
27
- input_img = img.detach().cpu().numpy()
28
- output_img = (color[:, :, :-1] * valid_mask +
29
- (1 - valid_mask) * input_img)
30
- else:
31
- output_img = color
32
- return output_img
33
-
34
- def create_scene(mesh, img, focal_length=500, camera_center=250, img_res=500):
35
- # Setup the scene
36
- scene = pyrender.Scene(bg_color=[1.0, 1.0, 1.0, 1.0],
37
- ambient_light=(0.3, 0.3, 0.3))
38
- # add mesh for camera
39
- camera_pose = np.eye(4)
40
- camera_rotation = np.eye(3, 3)
41
- camera_translation = np.array([0., 0, 2.5])
42
- camera_pose[:3, :3] = camera_rotation
43
- camera_pose[:3, 3] = camera_rotation @ camera_translation
44
- pyrencamera = pyrender.camera.IntrinsicsCamera(
45
- fx=focal_length, fy=focal_length,
46
- cx=camera_center, cy=camera_center)
47
- scene.add(pyrencamera, pose=camera_pose)
48
- # create and add light
49
- light = pyrender.PointLight(color=[1.0, 1.0, 1.0], intensity=1)
50
- light_pose = np.eye(4)
51
- for lp in [[1, 1, 1], [-1, 1, 1], [1, -1, 1], [-1, -1, 1]]:
52
- light_pose[:3, 3] = mesh.vertices.mean(0) + np.array(lp)
53
- # out_mesh.vertices.mean(0) + np.array(lp)
54
- scene.add(light, pose=light_pose)
55
- # add body mesh
56
- material = pyrender.MetallicRoughnessMaterial(
57
- metallicFactor=0.0,
58
- alphaMode='OPAQUE',
59
- baseColorFactor=(1.0, 1.0, 0.9, 1.0))
60
- mesh_images = []
61
-
62
- # resize input image to fit the mesh image height
63
- # print(img.shape)
64
- img_height = img_res
65
- img_width = int(img_height * img.shape[1] / img.shape[0])
66
- img = cv2.resize(img, (img_width, img_height))
67
- mesh_images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
68
-
69
- for sideview_angle in [0, 90, 180, 270]:
70
- out_mesh = mesh.copy()
71
- rot = trimesh.transformations.rotation_matrix(
72
- np.radians(sideview_angle), [0, 1, 0])
73
- out_mesh.apply_transform(rot)
74
- out_mesh = pyrender.Mesh.from_trimesh(
75
- out_mesh,
76
- material=material)
77
- mesh_pose = np.eye(4)
78
- scene.add(out_mesh, pose=mesh_pose, name='mesh')
79
- output_img = render_image(scene, img_res)
80
- output_img = pil_img.fromarray((output_img * 255).astype(np.uint8))
81
- output_img = np.asarray(output_img)[:, :, :3]
82
- mesh_images.append(output_img)
83
- # delete the previous mesh
84
- prev_mesh = scene.get_nodes(name='mesh').pop()
85
- scene.remove_node(prev_mesh)
86
-
87
- # show top and bottom views
88
- for topview_angle in [90, 270]:
89
- out_mesh = mesh.copy()
90
- rot = trimesh.transformations.rotation_matrix(
91
- np.radians(topview_angle), [1, 0, 0])
92
- out_mesh.apply_transform(rot)
93
- out_mesh = pyrender.Mesh.from_trimesh(
94
- out_mesh,
95
- material=material)
96
- mesh_pose = np.eye(4)
97
- scene.add(out_mesh, pose=mesh_pose, name='mesh')
98
- output_img = render_image(scene, img_res)
99
- output_img = pil_img.fromarray((output_img * 255).astype(np.uint8))
100
- output_img = np.asarray(output_img)[:, :, :3]
101
- mesh_images.append(output_img)
102
- # delete the previous mesh
103
- prev_mesh = scene.get_nodes(name='mesh').pop()
104
- scene.remove_node(prev_mesh)
105
-
106
- # stack images
107
- IMG = np.hstack(mesh_images)
108
- IMG = pil_img.fromarray(IMG)
109
- IMG.thumbnail((3000, 3000))
110
- return IMG
111
-
112
- # img = cv2.imread('../samples/prox_N3OpenArea_03301_01_s001_frame_00694.jpg')
113
- # mesh = trimesh.load('../samples/mesh.ply', process=False)
114
- # comb_img = create_scene(mesh, img)
115
- # comb_img.save('../samples/combined_image.png')
116
-
117
- def unsplit(img, palette):
118
- rgb_img = np.zeros((img.shape[0], img.shape[1], 3))
119
- for i in range(img.shape[0]):
120
- for j in range(img.shape[1]):
121
- id = np.argmax(img[i, j, :])
122
- rgb_img[i, j, :] = palette[id]
123
-
124
- return rgb_img
125
-
126
- def gen_render(output, normalize=True):
127
- img = output['img'].cpu().numpy()
128
- contact_labels_3d = output['contact_labels_3d_gt'].cpu().numpy()
129
- contact_labels_3d_pred = output['contact_labels_3d_pred'].cpu().numpy()
130
- sem_mask_gt = output['sem_mask_gt'].cpu().numpy()
131
- sem_mask_pred = output['sem_mask_pred'].cpu().numpy()
132
- part_mask_gt = output['part_mask_gt'].cpu().numpy()
133
- part_mask_pred = output['part_mask_pred'].cpu().numpy()
134
- contact_2d_gt_rgb = output['contact_2d_gt'].cpu().numpy()
135
- contact_2d_pred_rgb = output['contact_2d_pred_rgb'].cpu().numpy()
136
-
137
- mesh_path = './data/smpl/smpl_neutral_tpose.ply'
138
- gt_mesh = trimesh.load(mesh_path, process=False)
139
- pred_mesh = trimesh.load(mesh_path, process=False)
140
-
141
- img = np.transpose(img[0], (1, 2, 0))
142
- if normalize:
143
- # unnormalize the image before displaying
144
- mean = np.array(constants.IMG_NORM_MEAN, dtype=np.float32)
145
- std = np.array(constants.IMG_NORM_STD, dtype=np.float32)
146
- img = img * std + mean
147
- img = img * 255
148
- img = img.astype(np.uint8)
149
- color = np.array([0, 0, 0, 255])
150
- th = 0.5
151
-
152
- contact_labels_3d = contact_labels_3d[0, :]
153
- for vid, val in enumerate(contact_labels_3d):
154
- if val >= th:
155
- gt_mesh.visual.vertex_colors[vid] = color
156
-
157
- contact_labels_3d_pred = contact_labels_3d_pred[0, :]
158
- for vid, val in enumerate(contact_labels_3d_pred):
159
- if val >= th:
160
- pred_mesh.visual.vertex_colors[vid] = color
161
-
162
- gt_rend = create_scene(gt_mesh, img)
163
- pred_rend = create_scene(pred_mesh, img)
164
-
165
- sem_palette = [[220, 20, 60], [119, 11, 32], [0, 0, 142], [0, 0, 230], [106, 0, 228], [0, 60, 100], [0, 80, 100], [0, 0, 70], [0, 0, 192], [250, 170, 30], [100, 170, 30], [220, 220, 0], [175, 116, 175], [250, 0, 30], [165, 42, 42], [255, 77, 255], [0, 226, 252], [182, 182, 255], [0, 82, 0], [120, 166, 157], [110, 76, 0], [174, 57, 255], [199, 100, 0], [72, 0, 118], [255, 179, 240], [0, 125, 92], [209, 0, 151], [188, 208, 182], [0, 220, 176], [255, 99, 164], [92, 0, 73], [133, 129, 255], [78, 180, 255], [0, 228, 0], [174, 255, 243], [45, 89, 255], [134, 134, 103], [145, 148, 174], [255, 208, 186], [197, 226, 255], [171, 134, 1], [109, 63, 54], [207, 138, 255], [151, 0, 95], [9, 80, 61], [84, 105, 51], [74, 65, 105], [166, 196, 102], [208, 195, 210], [255, 109, 65], [0, 143, 149], [179, 0, 194], [209, 99, 106], [5, 121, 0], [227, 255, 205], [147, 186, 208], [153, 69, 1], [3, 95, 161], [163, 255, 0], [119, 0, 170], [0, 182, 199], [0, 165, 120], [183, 130, 88], [95, 32, 0], [130, 114, 135], [110, 129, 133], [166, 74, 118], [219, 142, 185], [79, 210, 114], [178, 90, 62], [65, 70, 15], [127, 167, 115], [59, 105, 106], [142, 108, 45], [196, 172, 0], [95, 54, 80], [128, 76, 255], [201, 57, 1], [246, 0, 122], [191, 162, 208], [255, 255, 128], [147, 211, 203], [150, 100, 100], [168, 171, 172], [146, 112, 198], [210, 170, 100], [92, 136, 89], [218, 88, 184], [241, 129, 0], [217, 17, 255], [124, 74, 181], [70, 70, 70], [255, 228, 255], [154, 208, 0], [193, 0, 92], [76, 91, 113], [255, 180, 195], [106, 154, 176], [230, 150, 140], [60, 143, 255], [128, 64, 128], [92, 82, 55], [254, 212, 124], [73, 77, 174], [255, 160, 98], [255, 255, 255], [104, 84, 109], [169, 164, 131], [225, 199, 255], [137, 54, 74], [135, 158, 223], [7, 246, 231], [107, 255, 200], [58, 41, 149], [183, 121, 142], [255, 73, 97], [107, 142, 35], [190, 153, 153], [146, 139, 141], [70, 130, 180], [134, 199, 156], [209, 226, 140], [96, 36, 108], [96, 96, 96], [64, 170, 64], [152, 251, 152], [208, 229, 228], [206, 186, 171], [152, 161, 64], [116, 112, 0], [0, 114, 143], [102, 102, 156], [250, 141, 255]]
166
- # part_palette = [(0,0,0), (128,0,0), (255,0,0), (0,85,0), (170,0,51), (255,85,0), (0,0,85), (0,119,221), (85,85,0), (0,85,85), (85,51,0), (52,86,128), (0,128,0), (0,0,255), (51,170,221), (0,255,255), (85,255,170), (170,255,85), (255,255,0), (255,170,0)]
167
- part_palette = [[0, 0, 0], [220, 20, 60], [119, 11, 32], [0, 0, 142], [0, 0, 230], [106, 0, 228], [0, 60, 100], [0, 80, 100], [0, 0, 70], [0, 0, 192], [250, 170, 30], [100, 170, 30], [220, 220, 0], [175, 116, 175], [250, 0, 30], [165, 42, 42], [255, 77, 255], [0, 226, 252], [182, 182, 255], [0, 82, 0], [120, 166, 157], [110, 76, 0], [174, 57, 255], [199, 100, 0], [72, 0, 118], [255, 179, 240]]
168
- hot_palette = [[0, 0, 0], [220, 20, 60], [119, 11, 32], [0, 0, 142], [0, 0, 230], [106, 0, 228], [0, 60, 100], [0, 80, 100], [0, 0, 70], [0, 0, 192], [250, 170, 30], [100, 170, 30], [220, 220, 0], [175, 116, 175], [250, 0, 30], [165, 42, 42], [255, 77, 255], [0, 226, 252]]
169
-
170
- sem_mask_gt = np.transpose(sem_mask_gt[0], (1, 2, 0))*255
171
- sem_mask_gt = sem_mask_gt.astype(np.uint8)
172
- sem_mask_pred = np.transpose(sem_mask_pred[0], (1, 2, 0))*255
173
- sem_mask_pred = sem_mask_pred.astype(np.uint8)
174
- part_mask_gt = np.transpose(part_mask_gt[0], (1, 2, 0))*255
175
- part_mask_gt = part_mask_gt.astype(np.uint8)
176
- part_mask_pred = np.transpose(part_mask_pred[0], (1, 2, 0))*255
177
- part_mask_pred = part_mask_pred.astype(np.uint8)
178
- contact_2d_gt_rgb = contact_2d_gt_rgb[0]*255
179
- contact_2d_gt_rgb = contact_2d_gt_rgb.astype(np.uint8)
180
- contact_2d_pred_rgb = contact_2d_pred_rgb[0]*255
181
- contact_2d_pred_rgb = contact_2d_pred_rgb.astype(np.uint8)
182
-
183
- sem_mask_rgb = unsplit(sem_mask_gt, sem_palette)
184
- sem_pred_rgb = unsplit(sem_mask_pred, sem_palette)
185
- part_mask_rgb = unsplit(part_mask_gt, part_palette)
186
- part_pred_rgb = unsplit(part_mask_pred, part_palette)
187
-
188
- sem_mask_rgb = sem_mask_rgb.astype(np.uint8)
189
- sem_pred_rgb = sem_pred_rgb.astype(np.uint8)
190
- part_mask_rgb = part_mask_rgb.astype(np.uint8)
191
- part_pred_rgb = part_pred_rgb.astype(np.uint8)
192
-
193
- sem_mask_rgb = pil_img.fromarray(sem_mask_rgb)
194
- sem_pred_rgb = pil_img.fromarray(sem_pred_rgb)
195
- part_mask_rgb = pil_img.fromarray(part_mask_rgb)
196
- part_pred_rgb = pil_img.fromarray(part_pred_rgb)
197
- contact_2d_gt_rgb = pil_img.fromarray(contact_2d_gt_rgb)
198
- contact_2d_pred_rgb = pil_img.fromarray(contact_2d_pred_rgb)
199
-
200
- tot_rend = pil_img.new('RGB', (3000, 2000))
201
- tot_rend.paste(gt_rend, (0, 0))
202
- tot_rend.paste(pred_rend, (0, 450))
203
- tot_rend.paste(sem_mask_rgb, (0, 900))
204
- tot_rend.paste(sem_pred_rgb, (400, 900))
205
- tot_rend.paste(part_mask_rgb, (0, 1300))
206
- tot_rend.paste(part_pred_rgb, (400, 1300))
207
- tot_rend.paste(contact_2d_gt_rgb, (0, 1700))
208
- tot_rend.paste(contact_2d_pred_rgb, (400, 1700))
209
- return tot_rend
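A tiny sketch of unsplit above, which recovers an RGB image from a per-channel mask by looking up the argmax channel in a palette (the 2x2 mask and three-colour palette are toy values, and unsplit is assumed to be in scope):

import numpy as np

palette = [[0, 0, 0], [255, 0, 0], [0, 255, 0]]       # background, class 1, class 2
mask = np.zeros((2, 2, 3))
mask[0, 0, 1] = 1.0                                    # pixel (0, 0) -> class 1
mask[1, 1, 2] = 1.0                                    # pixel (1, 1) -> class 2

rgb = unsplit(mask, palette)
print(rgb[0, 0], rgb[1, 1])                            # [255. 0. 0.] [0. 255. 0.]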