cleared files
- inference.py +0 -192
- tester.py +0 -61
- train.py +0 -94
- train/__init__.py +0 -0
- train/base_trainer.py +0 -103
- train/trainer_step.py +0 -291
- utils/__init__.py +0 -0
- utils/cluster.py +0 -99
- utils/colorwheel.py +0 -22
- utils/config.py +0 -196
- utils/default_hparams.py +0 -45
- utils/diff_renderer.py +0 -287
- utils/get_cfg.py +0 -17
- utils/hrnet.py +0 -625
- utils/image_utils.py +0 -444
- utils/kp_utils.py +0 -1114
- utils/loss.py +0 -207
- utils/mesh_utils.py +0 -6
- utils/metrics.py +0 -106
- utils/smpl_uv.py +0 -167
- vis/__pycache__/visualize.cpython-37.pyc +0 -0
- vis/visualize.py +0 -209
inference.py
DELETED
@@ -1,192 +0,0 @@
-import torch
-import os
-import glob
-import argparse
-import numpy as np
-import cv2
-import PIL.Image as pil_img
-from loguru import logger
-import shutil
-
-import trimesh
-import pyrender
-
-from models.deco import DECO
-from common import constants
-
-os.environ['PYOPENGL_PLATFORM'] = 'egl'
-
-if torch.cuda.is_available():
-    device = torch.device('cuda')
-else:
-    device = torch.device('cpu')
-
-def initiate_model(args):
-    deco_model = DECO('hrnet', True, device)
-
-    logger.info(f'Loading weights from {args.model_path}')
-    checkpoint = torch.load(args.model_path)
-    deco_model.load_state_dict(checkpoint['deco'], strict=True)
-
-    deco_model.eval()
-
-    return deco_model
-
-def render_image(scene, img_res, img=None, viewer=False):
-    '''
-    Render the given pyrender scene and return the image. Can also overlay the mesh on an image.
-    '''
-    if viewer:
-        pyrender.Viewer(scene, use_raymond_lighting=True)
-        return 0
-    else:
-        r = pyrender.OffscreenRenderer(viewport_width=img_res,
-                                       viewport_height=img_res,
-                                       point_size=1.0)
-        color, _ = r.render(scene, flags=pyrender.RenderFlags.RGBA)
-        color = color.astype(np.float32) / 255.0
-
-        if img is not None:
-            valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis]
-            input_img = img.detach().cpu().numpy()
-            output_img = (color[:, :, :-1] * valid_mask +
-                          (1 - valid_mask) * input_img)
-        else:
-            output_img = color
-        return output_img
-
-def create_scene(mesh, img, focal_length=500, camera_center=250, img_res=500):
-    # Setup the scene
-    scene = pyrender.Scene(bg_color=[1.0, 1.0, 1.0, 1.0],
-                           ambient_light=(0.3, 0.3, 0.3))
-    # add camera
-    camera_pose = np.eye(4)
-    camera_rotation = np.eye(3, 3)
-    camera_translation = np.array([0., 0, 2.5])
-    camera_pose[:3, :3] = camera_rotation
-    camera_pose[:3, 3] = camera_rotation @ camera_translation
-    pyrencamera = pyrender.camera.IntrinsicsCamera(
-        fx=focal_length, fy=focal_length,
-        cx=camera_center, cy=camera_center)
-    scene.add(pyrencamera, pose=camera_pose)
-    # create and add light
-    light = pyrender.PointLight(color=[1.0, 1.0, 1.0], intensity=1)
-    light_pose = np.eye(4)
-    for lp in [[1, 1, 1], [-1, 1, 1], [1, -1, 1], [-1, -1, 1]]:
-        light_pose[:3, 3] = mesh.vertices.mean(0) + np.array(lp)
-        # out_mesh.vertices.mean(0) + np.array(lp)
-        scene.add(light, pose=light_pose)
-    # add body mesh
-    material = pyrender.MetallicRoughnessMaterial(
-        metallicFactor=0.0,
-        alphaMode='OPAQUE',
-        baseColorFactor=(1.0, 1.0, 0.9, 1.0))
-    mesh_images = []
-
-    # resize input image to fit the mesh image height
-    img_height = img_res
-    img_width = int(img_height * img.shape[1] / img.shape[0])
-    img = cv2.resize(img, (img_width, img_height))
-    mesh_images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
-
-    for sideview_angle in [0, 90, 180, 270]:
-        out_mesh = mesh.copy()
-        rot = trimesh.transformations.rotation_matrix(
-            np.radians(sideview_angle), [0, 1, 0])
-        out_mesh.apply_transform(rot)
-        out_mesh = pyrender.Mesh.from_trimesh(
-            out_mesh,
-            material=material)
-        mesh_pose = np.eye(4)
-        scene.add(out_mesh, pose=mesh_pose, name='mesh')
-        output_img = render_image(scene, img_res)
-        output_img = pil_img.fromarray((output_img * 255).astype(np.uint8))
-        output_img = np.asarray(output_img)[:, :, :3]
-        mesh_images.append(output_img)
-        # delete the previous mesh
-        prev_mesh = scene.get_nodes(name='mesh').pop()
-        scene.remove_node(prev_mesh)
-
-    # show upside down view
-    for topview_angle in [90, 270]:
-        out_mesh = mesh.copy()
-        rot = trimesh.transformations.rotation_matrix(
-            np.radians(topview_angle), [1, 0, 0])
-        out_mesh.apply_transform(rot)
-        out_mesh = pyrender.Mesh.from_trimesh(
-            out_mesh,
-            material=material)
-        mesh_pose = np.eye(4)
-        scene.add(out_mesh, pose=mesh_pose, name='mesh')
-        output_img = render_image(scene, img_res)
-        output_img = pil_img.fromarray((output_img * 255).astype(np.uint8))
-        output_img = np.asarray(output_img)[:, :, :3]
-        mesh_images.append(output_img)
-        # delete the previous mesh
-        prev_mesh = scene.get_nodes(name='mesh').pop()
-        scene.remove_node(prev_mesh)
-
-    # stack images
-    IMG = np.hstack(mesh_images)
-    IMG = pil_img.fromarray(IMG)
-    IMG.thumbnail((3000, 3000))
-    return IMG
-
-def main(args):
-    if os.path.isdir(args.img_src):
-        images = glob.iglob(args.img_src + '/*', recursive=True)
-    else:
-        images = [args.img_src]
-
-    deco_model = initiate_model(args)
-
-    smpl_path = os.path.join(constants.SMPL_MODEL_DIR, 'smpl_neutral_tpose.ply')
-
-    for img_name in images:
-        img = cv2.imread(img_name)
-        img = cv2.resize(img, (256, 256), cv2.INTER_CUBIC)
-        img = img.transpose(2, 0, 1) / 255.0
-        img = img[np.newaxis, :, :, :]
-        img = torch.tensor(img, dtype=torch.float32).to(device)
-
-        cont, _, _ = deco_model(img)
-        cont = cont.detach().cpu().numpy().squeeze()
-        cont_smpl = []
-        for indx, i in enumerate(cont):
-            if i >= 0.5:
-                cont_smpl.append(indx)
-
-        img = img.detach().cpu().numpy()
-        img = np.transpose(img[0], (1, 2, 0))
-        img = img * 255
-        img = img.astype(np.uint8)
-
-        contact_smpl = np.zeros((1, 1, 6890))
-        contact_smpl[0][0][cont_smpl] = 1
-
-        body_model_smpl = trimesh.load(smpl_path, process=False)
-        for vert in range(body_model_smpl.visual.vertex_colors.shape[0]):
-            body_model_smpl.visual.vertex_colors[vert] = args.mesh_colour
-        body_model_smpl.visual.vertex_colors[cont_smpl] = args.annot_colour
-
-        rend = create_scene(body_model_smpl, img)
-        os.makedirs(os.path.join(args.out_dir, 'Renders'), exist_ok=True)
-        rend.save(os.path.join(args.out_dir, 'Renders', os.path.basename(img_name).split('.')[0] + '.png'))
-
-        out_dir = os.path.join(args.out_dir, 'Preds', os.path.basename(img_name).split('.')[0])
-        os.makedirs(out_dir, exist_ok=True)
-
-        logger.info(f'Saving mesh to {out_dir}')
-        shutil.copyfile(img_name, os.path.join(out_dir, os.path.basename(img_name)))
-        body_model_smpl.export(os.path.join(out_dir, 'pred.obj'))
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--img_src', help='Source of image(s). Can be file or directory', default='./demo_out', type=str)
-    parser.add_argument('--out_dir', help='Where to store images', default='./demo_out', type=str)
-    parser.add_argument('--model_path', help='Path to best model weights', default='./checkpoints/Release_Checkpoint/deco_best.pth', type=str)
-    parser.add_argument('--mesh_colour', help='Colour of the mesh', nargs='+', type=int, default=[130, 130, 130, 255])
-    parser.add_argument('--annot_colour', help='Colour of the contact annotation', nargs='+', type=int, default=[0, 255, 0, 255])
-    args = parser.parse_args()
-
-    main(args)
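For reference, a minimal sketch of the per-image pipeline the deleted script implemented: resize to 256x256, scale to [0, 1], forward through DECO, then threshold the per-vertex contact probabilities at 0.5. The helper names below are illustrative, not part of the deleted file; shapes and the threshold come from the code above.

    import cv2
    import numpy as np
    import torch

    def load_image_tensor(img_path, device):
        # BGR uint8 HxWx3 -> float32 1x3x256x256 in [0, 1], as in main() above
        img = cv2.imread(img_path)
        img = cv2.resize(img, (256, 256), cv2.INTER_CUBIC)
        img = img.transpose(2, 0, 1) / 255.0
        return torch.tensor(img[np.newaxis], dtype=torch.float32).to(device)

    def contact_vertex_ids(cont_probs, threshold=0.5):
        # cont_probs: per-vertex probabilities over the 6890 SMPL vertices
        return np.flatnonzero(cont_probs >= threshold)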
tester.py
DELETED
@@ -1,61 +0,0 @@
-import torch
-from torch.utils.data import DataLoader
-from loguru import logger
-
-from train.trainer_step import TrainStepper
-from train.base_trainer import evaluator
-from data.base_dataset import BaseDataset
-from models.deco import DECO
-from utils.config import parse_args, run_grid_search_experiments
-
-def test(hparams):
-    deco_model = DECO(hparams.TRAINING.ENCODER, hparams.TRAINING.CONTEXT, device)
-    pytorch_total_params = sum(p.numel() for p in deco_model.parameters() if p.requires_grad)
-    print('Total number of trainable parameters: ', pytorch_total_params)
-
-    solver = TrainStepper(deco_model, hparams.TRAINING.CONTEXT, hparams.OPTIMIZER.LR, hparams.TRAINING.LOSS_WEIGHTS, hparams.TRAINING.PAL_LOSS_WEIGHTS, device)
-
-    logger.info(f'Loading weights from {hparams.TRAINING.BEST_MODEL_PATH}')
-    _, _ = solver.load(hparams.TRAINING.BEST_MODEL_PATH)
-
-    # Run testing
-    for test_loader in val_loaders:
-        dataset_name = test_loader.dataset.dataset
-        test_dict, total_time = evaluator(test_loader, solver, hparams, 0, dataset_name, return_dict=True)
-
-        print('Test Contact Precision: ', test_dict['cont_precision'])
-        print('Test Contact Recall: ', test_dict['cont_recall'])
-        print('Test Contact F1 Score: ', test_dict['cont_f1'])
-        print('Test Contact FP Geo. Error: ', test_dict['fp_geo_err'])
-        print('Test Contact FN Geo. Error: ', test_dict['fn_geo_err'])
-        if hparams.TRAINING.CONTEXT:
-            print('Test Contact Semantic Segmentation IoU: ', test_dict['sem_iou'])
-            print('Test Contact Part Segmentation IoU: ', test_dict['part_iou'])
-        print('\nTime taken per image for evaluation: ', total_time)
-        print('-'*50)
-
-if __name__ == '__main__':
-    args = parse_args()
-    hparams = run_grid_search_experiments(
-        args,
-        script='tester.py',
-        change_wt_name=False
-    )
-
-    if torch.cuda.is_available():
-        device = torch.device('cuda')
-    else:
-        device = torch.device('cpu')
-
-    val_datasets = []
-    for ds in hparams.VALIDATION.DATASETS:
-        if ds in ['rich', 'prox']:
-            val_datasets.append(BaseDataset(ds, 'val', model_type='smplx', normalize=hparams.DATASET.NORMALIZE_IMAGES))
-        elif ds in ['damon']:
-            val_datasets.append(BaseDataset(ds, 'val', model_type='smpl', normalize=hparams.DATASET.NORMALIZE_IMAGES))
-        else:
-            raise ValueError('Dataset not supported')
-
-    val_loaders = [DataLoader(val_dataset, batch_size=hparams.DATASET.BATCH_SIZE, shuffle=False, num_workers=hparams.DATASET.NUM_WORKERS) for val_dataset in val_datasets]
-
-    test(hparams)
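A sketch of consuming the evaluator output directly, using the same dictionary keys the script prints; `loader`, `solver`, and `hparams` are assumed to be constructed as in the deleted script above.

    from train.base_trainer import evaluator

    test_dict, time_per_image = evaluator(loader, solver, hparams, 0,
                                          loader.dataset.dataset, return_dict=True)
    for key in ('cont_precision', 'cont_recall', 'cont_f1',
                'fp_geo_err', 'fn_geo_err'):
        print(f'{key}: {test_dict[key]:.4f}')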
train.py
DELETED
@@ -1,94 +0,0 @@
-import torch
-from torch.utils.data import DataLoader
-import os
-
-from train.trainer_step import TrainStepper
-from train.base_trainer import trainer, evaluator
-from data.base_dataset import BaseDataset
-from data.mixed_dataset import MixedDataset
-from models.deco import DECO
-from utils.config import parse_args, run_grid_search_experiments
-
-def train(hparams):
-    deco_model = DECO(hparams.TRAINING.ENCODER, hparams.TRAINING.CONTEXT, device)
-
-    solver = TrainStepper(deco_model, hparams.TRAINING.CONTEXT, hparams.OPTIMIZER.LR, hparams.TRAINING.LOSS_WEIGHTS, hparams.TRAINING.PAL_LOSS_WEIGHTS, device)
-
-    vb_f1 = 0
-    start_ep = 0
-    num = 0
-    k = True
-    latest_model_path = hparams.TRAINING.BEST_MODEL_PATH.replace('best', 'latest')
-    if os.path.exists(latest_model_path):
-        _, vb_f1 = solver.load(hparams.TRAINING.BEST_MODEL_PATH)
-        start_ep, _ = solver.load(latest_model_path)
-
-    for epoch in range(start_ep + 1, hparams.TRAINING.NUM_EPOCHS + 1):
-        # Train one epoch
-        trainer(epoch, train_loader, solver, hparams)
-        # Run evaluation
-        vc_f1 = None
-        for val_loader in val_loaders:
-            dataset_name = val_loader.dataset.dataset
-            vc_f1_ds = evaluator(val_loader, solver, hparams, epoch, dataset_name, normalize=hparams.DATASET.NORMALIZE_IMAGES)
-            if dataset_name == hparams.VALIDATION.MAIN_DATASET:
-                vc_f1 = vc_f1_ds
-        if vc_f1 is None:
-            raise ValueError('Main dataset not found in validation datasets')
-
-        print('Learning rate: ', solver.lr)
-
-        print('---------------------------------------------')
-        print('---------------------------------------------')
-
-        solver.save(epoch, vc_f1, latest_model_path)
-
-        if epoch % hparams.TRAINING.CHECKPOINT_EPOCHS == 0:
-            inter_model_path = latest_model_path.replace('latest', 'epoch_' + str(epoch).zfill(3))
-            solver.save(epoch, vc_f1, inter_model_path)
-
-        if vc_f1 < vb_f1:
-            num += 1
-            print('Not Saving model: Best Val F1 = ', vb_f1, ' Current Val F1 = ', vc_f1)
-        else:
-            num = 0
-            vb_f1 = vc_f1
-            print('Saving model...')
-            solver.save(epoch, vb_f1, hparams.TRAINING.BEST_MODEL_PATH)
-
-        if num >= hparams.OPTIMIZER.NUM_UPDATE_LR: solver.update_lr()
-        if num >= hparams.TRAINING.NUM_EARLY_STOP:
-            print('Early Stop')
-            k = False
-
-        if k: continue
-        else: break
-
-
-if __name__ == '__main__':
-    args = parse_args()
-    hparams = run_grid_search_experiments(
-        args,
-        script='train.py',
-    )
-
-    if torch.cuda.is_available():
-        device = torch.device('cuda')
-    else:
-        device = torch.device('cpu')
-
-    train_dataset = MixedDataset(hparams.TRAINING.DATASETS, 'train', dataset_mix_pdf=hparams.TRAINING.DATASET_MIX_PDF, normalize=hparams.DATASET.NORMALIZE_IMAGES)
-
-    val_datasets = []
-    for ds in hparams.VALIDATION.DATASETS:
-        if ds in ['rich', 'prox']:
-            val_datasets.append(BaseDataset(ds, 'val', model_type='smplx', normalize=hparams.DATASET.NORMALIZE_IMAGES))
-        elif ds in ['damon']:
-            val_datasets.append(BaseDataset(ds, 'val', model_type='smpl', normalize=hparams.DATASET.NORMALIZE_IMAGES))
-        else:
-            raise ValueError('Dataset not supported')
-
-    train_loader = DataLoader(train_dataset, hparams.DATASET.BATCH_SIZE, shuffle=True, num_workers=hparams.DATASET.NUM_WORKERS)
-    val_loaders = [DataLoader(val_dataset, batch_size=hparams.DATASET.BATCH_SIZE, shuffle=False, num_workers=hparams.DATASET.NUM_WORKERS) for val_dataset in val_datasets]
-
-    train(hparams)
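The checkpoint bookkeeping above hinges on simple string substitution over TRAINING.BEST_MODEL_PATH; a sketch with a hypothetical path (run_grid_search_experiments rewrites the real path to end in 'best.pth'):

    best_path = 'weights/deco_exp_best.pth'            # hypothetical example path
    latest_path = best_path.replace('best', 'latest')  # weights/deco_exp_latest.pth
    epoch_path = latest_path.replace('latest', 'epoch_' + str(7).zfill(3))
    print(epoch_path)                                  # weights/deco_exp_epoch_007.pth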
train/__init__.py
DELETED
File without changes
train/base_trainer.py
DELETED
@@ -1,103 +0,0 @@
-from tqdm import tqdm
-from utils.metrics import metric, precision_recall_f1score, det_error_metric
-import torch
-import numpy as np
-from vis.visualize import gen_render
-
-
-def trainer(epoch, train_loader, solver, hparams, compute_metrics=False):
-
-    total_epochs = hparams.TRAINING.NUM_EPOCHS
-    print('Training Epoch {}/{}'.format(epoch, total_epochs))
-
-    length = len(train_loader)
-    iterator = tqdm(enumerate(train_loader), total=length, leave=False, desc=f'Training Epoch: {epoch}/{total_epochs}')
-    for step, batch in iterator:
-        losses, output = solver.optimize(batch)
-    return losses, output
-
-@torch.no_grad()
-def evaluator(val_loader, solver, hparams, epoch=0, dataset_name='Unknown', normalize=True, return_dict=False):
-    total_epochs = hparams.TRAINING.NUM_EPOCHS
-
-    batch_size = val_loader.batch_size
-    dataset_size = len(val_loader.dataset)
-    print(f'Dataset size: {dataset_size}')
-
-    val_epoch_cont_pre = np.zeros(dataset_size)
-    val_epoch_cont_rec = np.zeros(dataset_size)
-    val_epoch_cont_f1 = np.zeros(dataset_size)
-    val_epoch_fp_geo_err = np.zeros(dataset_size)
-    val_epoch_fn_geo_err = np.zeros(dataset_size)
-    if hparams.TRAINING.CONTEXT:
-        val_epoch_sem_iou = np.zeros(dataset_size)
-        val_epoch_part_iou = np.zeros(dataset_size)
-
-    val_epoch_cont_loss = np.zeros(dataset_size)
-
-    total_time = 0
-
-    rend_images = []
-
-    eval_dict = {}
-
-    length = len(val_loader)
-    iterator = tqdm(enumerate(val_loader), total=length, leave=False, desc=f'Evaluating {dataset_name.capitalize()} Epoch: {epoch}/{total_epochs}')
-    for step, batch in iterator:
-        curr_batch_size = batch['img'].shape[0]
-        losses, output, time_taken = solver.evaluate(batch)
-
-        val_epoch_cont_loss[step * batch_size:step * batch_size + curr_batch_size] = losses['cont_loss'].cpu().numpy()
-
-        # compute metrics
-        contact_labels_3d = output['contact_labels_3d_gt']
-        has_contact_3d = output['has_contact_3d']
-        # check if any value in has_contact_3d tensor is 0
-        assert torch.any(has_contact_3d == 0) == False, 'has_contact_3d tensor has 0 values'
-
-        contact_labels_3d_pred = output['contact_labels_3d_pred']
-        if hparams.TRAINING.CONTEXT:
-            sem_mask_gt = output['sem_mask_gt']
-            sem_seg_pred = output['sem_mask_pred']
-            part_mask_gt = output['part_mask_gt']
-            part_seg_pred = output['part_mask_pred']
-
-        cont_pre, cont_rec, cont_f1 = precision_recall_f1score(contact_labels_3d, contact_labels_3d_pred)
-        fp_geo_err, fn_geo_err = det_error_metric(contact_labels_3d_pred, contact_labels_3d)
-        if hparams.TRAINING.CONTEXT:
-            sem_iou = metric(sem_mask_gt, sem_seg_pred)
-            part_iou = metric(part_mask_gt, part_seg_pred)
-
-        val_epoch_cont_pre[step * batch_size:step * batch_size + curr_batch_size] = cont_pre.cpu().numpy()
-        val_epoch_cont_rec[step * batch_size:step * batch_size + curr_batch_size] = cont_rec.cpu().numpy()
-        val_epoch_cont_f1[step * batch_size:step * batch_size + curr_batch_size] = cont_f1.cpu().numpy()
-        val_epoch_fp_geo_err[step * batch_size:step * batch_size + curr_batch_size] = fp_geo_err.cpu().numpy()
-        val_epoch_fn_geo_err[step * batch_size:step * batch_size + curr_batch_size] = fn_geo_err.cpu().numpy()
-        if hparams.TRAINING.CONTEXT:
-            val_epoch_sem_iou[step * batch_size:step * batch_size + curr_batch_size] = sem_iou.cpu().numpy()
-            val_epoch_part_iou[step * batch_size:step * batch_size + curr_batch_size] = part_iou.cpu().numpy()
-
-        total_time += time_taken
-
-        # logging every summary_steps steps
-        if step % hparams.VALIDATION.SUMMARY_STEPS == 0:
-            if hparams.TRAINING.CONTEXT:
-                rend = gen_render(output, normalize)
-                rend_images.append(rend)
-
-    eval_dict['cont_precision'] = np.sum(val_epoch_cont_pre) / dataset_size
-    eval_dict['cont_recall'] = np.sum(val_epoch_cont_rec) / dataset_size
-    eval_dict['cont_f1'] = np.sum(val_epoch_cont_f1) / dataset_size
-    eval_dict['fp_geo_err'] = np.sum(val_epoch_fp_geo_err) / dataset_size
-    eval_dict['fn_geo_err'] = np.sum(val_epoch_fn_geo_err) / dataset_size
-    if hparams.TRAINING.CONTEXT:
-        eval_dict['sem_iou'] = np.sum(val_epoch_sem_iou) / dataset_size
-        eval_dict['part_iou'] = np.sum(val_epoch_part_iou) / dataset_size
-    eval_dict['images'] = rend_images
-
-    total_time /= dataset_size
-
-    val_epoch_cont_loss = np.sum(val_epoch_cont_loss) / dataset_size
-    if return_dict:
-        return eval_dict, total_time
-    return eval_dict['cont_f1']
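The evaluator accumulates per-sample metrics into preallocated arrays indexed by batch offset, so the final numbers are plain means over the dataset. A self-contained sketch of that pattern:

    import numpy as np

    dataset_size, batch_size = 10, 4
    per_sample_f1 = np.zeros(dataset_size)
    # three batches of sizes 4, 4, 2, as a DataLoader without drop_last yields
    for step, batch_f1 in enumerate([np.full(4, 0.8), np.full(4, 0.6), np.full(2, 0.7)]):
        start = step * batch_size
        per_sample_f1[start:start + len(batch_f1)] = batch_f1
    print(np.sum(per_sample_f1) / dataset_size)  # 0.7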
train/trainer_step.py
DELETED
@@ -1,291 +0,0 @@
-from utils.loss import sem_loss_function, class_loss_function, pixel_anchoring_function
-import torch
-import os
-import time
-
-
-class TrainStepper():
-    def __init__(self, deco_model, context, learning_rate, loss_weight, pal_loss_weight, device):
-        self.device = device
-
-        self.model = deco_model
-        self.context = context
-
-        if self.context:
-            self.optimizer_sem = torch.optim.Adam(params=list(self.model.encoder_sem.parameters()) + list(self.model.decoder_sem.parameters()),
-                                                  lr=learning_rate, weight_decay=0.0001)
-            self.optimizer_part = torch.optim.Adam(
-                params=list(self.model.encoder_part.parameters()) + list(self.model.decoder_part.parameters()), lr=learning_rate,
-                weight_decay=0.0001)
-            self.optimizer_contact = torch.optim.Adam(
-                params=list(self.model.encoder_sem.parameters()) + list(self.model.encoder_part.parameters()) + list(
-                    self.model.cross_att.parameters()) + list(self.model.classif.parameters()), lr=learning_rate, weight_decay=0.0001)
-
-        if self.context: self.sem_loss = sem_loss_function().to(device)
-        self.class_loss = class_loss_function().to(device)
-        self.pixel_anchoring_loss_smplx = pixel_anchoring_function(model_type='smplx').to(device)
-        self.pixel_anchoring_loss_smpl = pixel_anchoring_function(model_type='smpl').to(device)
-        self.lr = learning_rate
-        self.loss_weight = loss_weight
-        self.pal_loss_weight = pal_loss_weight
-
-    def optimize(self, batch):
-        self.model.train()
-
-        img_paths = batch['img_path']
-        img = batch['img'].to(self.device)
-
-        img_scale_factor = batch['img_scale_factor'].to(self.device)
-
-        pose = batch['pose'].to(self.device)
-        betas = batch['betas'].to(self.device)
-        transl = batch['transl'].to(self.device)
-        has_smpl = batch['has_smpl'].to(self.device)
-        is_smplx = batch['is_smplx'].to(self.device)
-
-        cam_k = batch['cam_k'].to(self.device)
-
-        gt_contact_labels_3d = batch['contact_label_3d'].to(self.device)
-        has_contact_3d = batch['has_contact_3d'].to(self.device)
-
-        if self.context:
-            sem_mask_gt = batch['sem_mask'].to(self.device)
-            part_mask_gt = batch['part_mask'].to(self.device)
-
-        polygon_contact_2d = batch['polygon_contact_2d'].to(self.device)
-        has_polygon_contact_2d = batch['has_polygon_contact_2d'].to(self.device)
-
-        # Forward pass
-        if self.context:
-            cont, sem_mask_pred, part_mask_pred = self.model(img)
-        else:
-            cont = self.model(img)
-
-        if self.context:
-            loss_sem = self.sem_loss(sem_mask_gt, sem_mask_pred)
-            loss_part = self.sem_loss(part_mask_gt, part_mask_pred)
-        valid_contact_3d = has_contact_3d
-        loss_cont = self.class_loss(gt_contact_labels_3d, cont, valid_contact_3d)
-        valid_polygon_contact_2d = has_polygon_contact_2d
-
-        if self.pal_loss_weight > 0 and (is_smplx == 0).sum() > 0:
-            smpl_body_params = {'pose': pose[is_smplx == 0], 'betas': betas[is_smplx == 0],
-                                'transl': transl[is_smplx == 0],
-                                'has_smpl': has_smpl[is_smplx == 0]}
-            loss_pix_anchoring_smpl, contact_2d_pred_rgb_smpl, _ = self.pixel_anchoring_loss_smpl(
-                cont[is_smplx == 0],
-                smpl_body_params,
-                cam_k[is_smplx == 0],
-                img_scale_factor[is_smplx == 0],
-                polygon_contact_2d[is_smplx == 0],
-                valid_polygon_contact_2d[is_smplx == 0])
-            # weigh the smpl loss based on the number of smpl samples
-            loss_pix_anchoring = loss_pix_anchoring_smpl * (is_smplx == 0).sum() / len(is_smplx)
-            contact_2d_pred_rgb = contact_2d_pred_rgb_smpl
-        else:
-            loss_pix_anchoring = 0
-            contact_2d_pred_rgb = torch.zeros_like(polygon_contact_2d)
-
-        if self.context: loss = loss_sem + loss_part + self.loss_weight * loss_cont + self.pal_loss_weight * loss_pix_anchoring
-        else: loss = self.loss_weight * loss_cont + self.pal_loss_weight * loss_pix_anchoring
-
-        if self.context:
-            self.optimizer_sem.zero_grad()
-            self.optimizer_part.zero_grad()
-            self.optimizer_contact.zero_grad()
-
-        loss.backward()
-
-        if self.context:
-            self.optimizer_sem.step()
-            self.optimizer_part.step()
-            self.optimizer_contact.step()
-
-        if self.context:
-            losses = {'sem_loss': loss_sem,
-                      'part_loss': loss_part,
-                      'cont_loss': loss_cont,
-                      'pal_loss': loss_pix_anchoring,
-                      'total_loss': loss}
-        else:
-            losses = {'cont_loss': loss_cont,
-                      'pal_loss': loss_pix_anchoring,
-                      'total_loss': loss}
-
-        if self.context:
-            output = {
-                'img': img,
-                'sem_mask_gt': sem_mask_gt,
-                'sem_mask_pred': sem_mask_pred,
-                'part_mask_gt': part_mask_gt,
-                'part_mask_pred': part_mask_pred,
-                'has_contact_2d': has_polygon_contact_2d,
-                'contact_2d_gt': polygon_contact_2d,
-                'contact_2d_pred_rgb': contact_2d_pred_rgb,
-                'has_contact_3d': has_contact_3d,
-                'contact_labels_3d_gt': gt_contact_labels_3d,
-                'contact_labels_3d_pred': cont}
-        else:
-            output = {
-                'img': img,
-                'has_contact_2d': has_polygon_contact_2d,
-                'contact_2d_gt': polygon_contact_2d,
-                'contact_2d_pred_rgb': contact_2d_pred_rgb,
-                'has_contact_3d': has_contact_3d,
-                'contact_labels_3d_gt': gt_contact_labels_3d,
-                'contact_labels_3d_pred': cont}
-
-        return losses, output
-
-    @torch.no_grad()
-    def evaluate(self, batch):
-        self.model.eval()
-
-        img_paths = batch['img_path']
-        img = batch['img'].to(self.device)
-
-        img_scale_factor = batch['img_scale_factor'].to(self.device)
-
-        pose = batch['pose'].to(self.device)
-        betas = batch['betas'].to(self.device)
-        transl = batch['transl'].to(self.device)
-        has_smpl = batch['has_smpl'].to(self.device)
-        is_smplx = batch['is_smplx'].to(self.device)
-
-        cam_k = batch['cam_k'].to(self.device)
-
-        gt_contact_labels_3d = batch['contact_label_3d'].to(self.device)
-        has_contact_3d = batch['has_contact_3d'].to(self.device)
-
-        if self.context:
-            sem_mask_gt = batch['sem_mask'].to(self.device)
-            part_mask_gt = batch['part_mask'].to(self.device)
-
-        polygon_contact_2d = batch['polygon_contact_2d'].to(self.device)
-        has_polygon_contact_2d = batch['has_polygon_contact_2d'].to(self.device)
-
-        # Forward pass
-        initial_time = time.time()
-        if self.context: cont, sem_mask_pred, part_mask_pred = self.model(img)
-        else: cont = self.model(img)
-        time_taken = time.time() - initial_time
-
-        if self.context:
-            loss_sem = self.sem_loss(sem_mask_gt, sem_mask_pred)
-            loss_part = self.sem_loss(part_mask_gt, part_mask_pred)
-        valid_contact_3d = has_contact_3d
-        loss_cont = self.class_loss(gt_contact_labels_3d, cont, valid_contact_3d)
-        valid_polygon_contact_2d = has_polygon_contact_2d
-
-        if self.pal_loss_weight > 0 and (is_smplx == 0).sum() > 0:  # PAL loss only on 2D contacts in HOT which only has SMPL
-            smpl_body_params = {'pose': pose[is_smplx == 0], 'betas': betas[is_smplx == 0], 'transl': transl[is_smplx == 0],
-                                'has_smpl': has_smpl[is_smplx == 0]}
-            loss_pix_anchoring_smpl, contact_2d_pred_rgb_smpl, _ = self.pixel_anchoring_loss_smpl(
-                cont[is_smplx == 0],
-                smpl_body_params,
-                cam_k[is_smplx == 0],
-                img_scale_factor[is_smplx == 0],
-                polygon_contact_2d[is_smplx == 0],
-                valid_polygon_contact_2d[is_smplx == 0])
-            # weigh the smpl loss based on the number of smpl samples
-            contact_2d_pred_rgb = contact_2d_pred_rgb_smpl
-            loss_pix_anchoring = loss_pix_anchoring_smpl * (is_smplx == 0).sum() / len(is_smplx)
-        else:
-            loss_pix_anchoring = 0
-            contact_2d_pred_rgb = torch.zeros_like(polygon_contact_2d)
-
-        if self.context: loss = loss_sem + loss_part + self.loss_weight * loss_cont + self.pal_loss_weight * loss_pix_anchoring
-        else: loss = self.loss_weight * loss_cont + self.pal_loss_weight * loss_pix_anchoring
-
-        if self.context:
-            losses = {'sem_loss': loss_sem,
-                      'part_loss': loss_part,
-                      'cont_loss': loss_cont,
-                      'pal_loss': loss_pix_anchoring,
-                      'total_loss': loss}
-        else:
-            losses = {'cont_loss': loss_cont,
-                      'pal_loss': loss_pix_anchoring,
-                      'total_loss': loss}
-
-        if self.context:
-            output = {
-                'img': img,
-                'sem_mask_gt': sem_mask_gt,
-                'sem_mask_pred': sem_mask_pred,
-                'part_mask_gt': part_mask_gt,
-                'part_mask_pred': part_mask_pred,
-                'has_contact_2d': has_polygon_contact_2d,
-                'contact_2d_gt': polygon_contact_2d,
-                'contact_2d_pred_rgb': contact_2d_pred_rgb,
-                'has_contact_3d': has_contact_3d,
-                'contact_labels_3d_gt': gt_contact_labels_3d,
-                'contact_labels_3d_pred': cont}
-        else:
-            output = {
-                'img': img,
-                'has_contact_2d': has_polygon_contact_2d,
-                'contact_2d_gt': polygon_contact_2d,
-                'contact_2d_pred_rgb': contact_2d_pred_rgb,
-                'has_contact_3d': has_contact_3d,
-                'contact_labels_3d_gt': gt_contact_labels_3d,
-                'contact_labels_3d_pred': cont}
-
-        return losses, output, time_taken
-
-    def save(self, ep, f1, model_path):
-        # create model directory if it does not exist
-        os.makedirs(os.path.dirname(model_path), exist_ok=True)
-        if self.context:
-            torch.save({
-                'epoch': ep,
-                'deco': self.model.state_dict(),
-                'f1': f1,
-                'sem_optim': self.optimizer_sem.state_dict(),
-                'part_optim': self.optimizer_part.state_dict(),
-                'contact_optim': self.optimizer_contact.state_dict()
-            },
-                model_path)
-        else:
-            torch.save({
-                'epoch': ep,
-                'deco': self.model.state_dict(),
-                'f1': f1,
-                'sem_optim': self.optimizer_sem.state_dict(),
-                'part_optim': self.optimizer_part.state_dict(),
-                'contact_optim': self.optimizer_contact.state_dict()
-            },
-                model_path)
-
-    def load(self, model_path):
-        print(f'~~~ Loading existing checkpoint from {model_path} ~~~')
-        checkpoint = torch.load(model_path)
-        self.model.load_state_dict(checkpoint['deco'], strict=True)
-
-        if self.context:
-            self.optimizer_sem.load_state_dict(checkpoint['sem_optim'])
-            self.optimizer_part.load_state_dict(checkpoint['part_optim'])
-            self.optimizer_contact.load_state_dict(checkpoint['contact_optim'])
-        epoch = checkpoint['epoch']
-        f1 = checkpoint['f1']
-        return epoch, f1
-
-    def update_lr(self, factor=2):
-        if factor:
-            new_lr = self.lr / factor
-
-        if self.context:
-            self.optimizer_sem = torch.optim.Adam(params=list(self.model.encoder_sem.parameters()) + list(self.model.decoder_sem.parameters()),
-                                                  lr=new_lr, weight_decay=0.0001)
-            self.optimizer_part = torch.optim.Adam(
-                params=list(self.model.encoder_part.parameters()) + list(self.model.decoder_part.parameters()), lr=new_lr, weight_decay=0.0001)
-            self.optimizer_contact = torch.optim.Adam(
-                params=list(self.model.encoder_sem.parameters()) + list(self.model.encoder_part.parameters()) + list(
-                    self.model.cross_att.parameters()) + list(self.model.classif.parameters()), lr=new_lr, weight_decay=0.0001)
-
-        print('update learning rate: %f -> %f' % (self.lr, new_lr))
-        self.lr = new_lr
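optimize() zeroes all three Adam optimizers before one backward pass and steps them all afterwards; because the optimizers share encoder parameters, the shared weights receive two Adam updates per step. A minimal sketch of that pattern (the two-layer model is illustrative):

    import torch

    enc, dec = torch.nn.Linear(4, 4), torch.nn.Linear(4, 1)
    opt_enc = torch.optim.Adam(enc.parameters(), lr=5e-5, weight_decay=0.0001)
    opt_all = torch.optim.Adam(list(enc.parameters()) + list(dec.parameters()),
                               lr=5e-5, weight_decay=0.0001)

    loss = dec(enc(torch.randn(2, 4))).pow(2).mean()
    opt_enc.zero_grad(); opt_all.zero_grad()
    loss.backward()
    opt_enc.step(); opt_all.step()  # encoder params are updated by both optimizers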
utils/__init__.py
DELETED
File without changes
utils/cluster.py
DELETED
@@ -1,99 +0,0 @@
-import os
-import sys
-import stat
-import shutil
-import subprocess
-
-from loguru import logger
-
-GPUS = {
-    'v100-v16': ('\"Tesla V100-PCIE-16GB\"', 'tesla', 16000),
-    'v100-p32': ('\"Tesla V100-PCIE-32GB\"', 'tesla', 32000),
-    'v100-s32': ('\"Tesla V100-SXM2-32GB\"', 'tesla', 32000),
-    'v100-p16': ('\"Tesla P100-PCIE-16GB\"', 'tesla', 16000),
-}
-
-def get_gpus(min_mem=10000, arch=('tesla', 'quadro', 'rtx')):
-    gpu_names = []
-    for k, (gpu_name, gpu_arch, gpu_mem) in GPUS.items():
-        if gpu_mem >= min_mem and gpu_arch in arch:
-            gpu_names.append(gpu_name)
-
-    assert len(gpu_names) > 0, 'Suitable GPU model could not be found'
-
-    return gpu_names
-
-
-def execute_task_on_cluster(
-        script,
-        exp_name,
-        output_dir,
-        condor_dir,
-        cfg_file,
-        num_exp=1,
-        exp_opts=None,
-        bid_amount=10,
-        num_workers=2,
-        memory=64000,
-        gpu_min_mem=10000,
-        gpu_arch=('tesla', 'quadro', 'rtx'),
-        num_gpus=1
-):
-    # copy config to a new experiment directory and source from there.
-    # this makes sure the correct config is copied even if you change the config file
-    # after starting the experiment and before the first job is submitted
-    temp_config_dir = os.path.join(os.path.dirname(condor_dir), 'temp_configs', exp_name)
-    os.makedirs(temp_config_dir, exist_ok=True)
-    new_cfg_file = os.path.join(temp_config_dir, 'config.yaml')
-    shutil.copy(src=cfg_file, dst=new_cfg_file)
-
-    gpus = get_gpus(min_mem=gpu_min_mem, arch=gpu_arch)
-
-    gpus = ' || '.join([f'CUDADeviceName=={x}' for x in gpus])
-
-    condor_log_dir = os.path.join(condor_dir, 'condorlog', exp_name)
-    os.makedirs(condor_log_dir, exist_ok=True)
-    submission = f'executable = {condor_log_dir}/{exp_name}_run.sh\n' \
-                 'arguments = $(Process) $(Cluster)\n' \
-                 f'error = {condor_log_dir}/{exp_name}_$(Cluster).$(Process).err\n' \
-                 f'output = {condor_log_dir}/{exp_name}_$(Cluster).$(Process).out\n' \
-                 f'log = {condor_log_dir}/{exp_name}_$(Cluster).$(Process).log\n' \
-                 f'request_memory = {memory}\n' \
-                 f'request_cpus={int(num_workers)}\n' \
-                 f'request_gpus={num_gpus}\n' \
-                 f'requirements={gpus}\n' \
-                 f'+MaxRunningPrice = 500\n' \
-                 f'queue {num_exp}'
-    # f'request_cpus={int(num_workers/2)}\n' \
-    # f'+RunningPriceExceededAction = \"kill\"\n' \
-    print('<<< Condor Submission >>> ')
-    print(submission)
-
-    with open(f'{condor_log_dir}/{exp_name}_submit.sub', 'w') as f:
-        f.write(submission)
-
-    # output_dir = os.path.join(output_dir, exp_name)
-    logger.info(f'The logs for this experiment can be found under: {condor_log_dir}')
-    logger.info(f'The outputs for this experiment can be found under: {output_dir}')
-    ## This is the trick. Notice there is no --cluster here
-    bash = 'export PYTHONBUFFERED=1\n export PATH=$PATH\n ' \
-           f'{sys.executable} {script} --cfg {new_cfg_file} --cfg_id $1'
-
-    if exp_opts is not None:
-        bash += ' --opts '
-        for opt in exp_opts:
-            bash += f'{opt} '
-        bash += 'SYSTEM.CLUSTER_NODE $2.$1'
-    else:
-        bash += ' --opts SYSTEM.CLUSTER_NODE $2.$1'
-
-    executable_path = f'{condor_log_dir}/{exp_name}_run.sh'
-
-    with open(executable_path, 'w') as f:
-        f.write(bash)
-
-    os.chmod(executable_path, stat.S_IRWXU)
-
-    cmd = ['condor_submit_bid', f'{bid_amount}', f'{condor_log_dir}/{exp_name}_submit.sub']
-    logger.info('Executing ' + ' '.join(cmd))
-    subprocess.call(cmd)
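A hypothetical call, to show which arguments map into the generated HTCondor submission file (every path here is illustrative):

    execute_task_on_cluster(
        script='train.py',
        exp_name='deco_rich',
        output_dir='deco_results/',
        condor_dir='/tmp/condor',
        cfg_file='configs/train.yaml',
        bid_amount=10,          # passed to condor_submit_bid
        num_workers=4,          # becomes request_cpus
        memory=64000,           # becomes request_memory
        gpu_min_mem=16000,      # filters the GPUS table above
    )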
utils/colorwheel.py
DELETED
@@ -1,22 +0,0 @@
-import cv2
-import numpy as np
-
-
-def make_color_wheel_image(img_width, img_height):
-    """
-    Creates a color wheel based image of given width and height
-    Args:
-        img_width (int): width of the output image
-        img_height (int): height of the output image
-
-    Returns:
-        opencv image (numpy array): color wheel based image
-    """
-    hue = np.fromfunction(lambda i, j: (np.arctan2(i-img_height/2, img_width/2-j) + np.pi)*(180/np.pi)/2,
-                          (img_height, img_width), dtype=np.float)
-    saturation = np.ones((img_height, img_width)) * 255
-    value = np.ones((img_height, img_width)) * 255
-    hsl = np.dstack((hue, saturation, value))
-    color_map = cv2.cvtColor(np.array(hsl, dtype=np.uint8), cv2.COLOR_HSV2BGR)
-    return color_map
-
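Usage sketch (note that np.float was deprecated in NumPy 1.20 and removed in 1.24, so on current NumPy the dtype above would need to be plain float):

    wheel = make_color_wheel_image(256, 256)   # BGR uint8 image
    cv2.imwrite('wheel.png', wheel)            # output path is illustrative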
utils/config.py
DELETED
@@ -1,196 +0,0 @@
-import itertools
-import operator
-import os
-import shutil
-import time
-from functools import reduce
-from typing import List, Union
-
-import configargparse
-import yaml
-from flatten_dict import flatten, unflatten
-from loguru import logger
-from yacs.config import CfgNode as CN
-
-from utils.cluster import execute_task_on_cluster
-from utils.default_hparams import hparams
-
-
-def parse_args():
-    def add_common_cmdline_args(parser):
-        # for cluster runs
-        parser.add_argument('--cfg', required=True, type=str, help='cfg file path')
-        parser.add_argument('--opts', default=[], nargs='*', help='additional options to update config')
-        parser.add_argument('--cfg_id', type=int, default=0, help='cfg id to run when multiple experiments are spawned')
-        parser.add_argument('--cluster', default=False, action='store_true', help='creates submission files for cluster')
-        parser.add_argument('--bid', type=int, default=10, help='amount of bid for cluster')
-        parser.add_argument('--memory', type=int, default=64000, help='memory amount for cluster')
-        parser.add_argument('--gpu_min_mem', type=int, default=12000, help='minimum amount of GPU memory')
-        parser.add_argument('--gpu_arch', default=['tesla', 'quadro', 'rtx'],
-                            nargs='*', help='additional options to update config')
-        parser.add_argument('--num_cpus', type=int, default=8, help='num cpus for cluster')
-        return parser
-
-    # For Blender main parser
-    arg_formatter = configargparse.ArgumentDefaultsHelpFormatter
-    cfg_parser = configargparse.YAMLConfigFileParser
-    description = 'PyTorch implementation of DECO'
-
-    parser = configargparse.ArgumentParser(formatter_class=arg_formatter,
-                                           config_file_parser_class=cfg_parser,
-                                           description=description,
-                                           prog='deco')
-
-    parser = add_common_cmdline_args(parser)
-
-    args = parser.parse_args()
-    print(args, end='\n\n')
-
-    return args
-
-def get_hparams_defaults():
-    """Get a yacs hparamsNode object with default values for my_project."""
-    # Return a clone so that the defaults will not be altered
-    # This is for the "local variable" use pattern
-    return hparams.clone()
-
-def update_hparams(hparams_file):
-    hparams = get_hparams_defaults()
-    hparams.merge_from_file(hparams_file)
-    return hparams.clone()
-
-def update_hparams_from_dict(cfg_dict):
-    hparams = get_hparams_defaults()
-    cfg = hparams.load_cfg(str(cfg_dict))
-    hparams.merge_from_other_cfg(cfg)
-    return hparams.clone()
-
-def get_grid_search_configs(config, excluded_keys=[]):
-    """
-    :param config: dictionary with the configurations
-    :return: The different configurations
-    """
-
-    def bool_to_string(x: Union[List[bool], bool]) -> Union[List[str], str]:
-        """
-        boolean to string conversion
-        :param x: list or bool to be converted
-        :return: string-converted value
-        """
-        if isinstance(x, bool):
-            return [str(x)]
-        for i, j in enumerate(x):
-            x[i] = str(j)
-        return x
-
-    # exclude from grid search
-
-    flattened_config_dict = flatten(config, reducer='path')
-    hyper_params = []
-
-    for k, v in flattened_config_dict.items():
-        if isinstance(v, list):
-            if k in excluded_keys:
-                flattened_config_dict[k] = ['+'.join(v)]
-            elif len(v) > 1:
-                hyper_params += [k]
-
-        if isinstance(v, list) and isinstance(v[0], bool):
-            flattened_config_dict[k] = bool_to_string(v)
-
-        if not isinstance(v, list):
-            if isinstance(v, bool):
-                flattened_config_dict[k] = bool_to_string(v)
-            else:
-                flattened_config_dict[k] = [v]
-
-    keys, values = zip(*flattened_config_dict.items())
-    experiments = [dict(zip(keys, v)) for v in itertools.product(*values)]
-
-    for exp_id, exp in enumerate(experiments):
-        for param in excluded_keys:
-            exp[param] = exp[param].strip().split('+')
-        for param_name, param_value in exp.items():
-            # print(param_name, type(param_value))
-            if isinstance(param_value, list) and (param_value[0] in ['True', 'False']):
-                exp[param_name] = [True if x == 'True' else False for x in param_value]
-            if param_value in ['True', 'False']:
-                if param_value == 'True':
-                    exp[param_name] = True
-                else:
-                    exp[param_name] = False
-
-        experiments[exp_id] = unflatten(exp, splitter='path')
-
-    return experiments, hyper_params
-
-def get_from_dict(dict, keys):
-    return reduce(operator.getitem, keys, dict)
-
-def save_dict_to_yaml(obj, filename, mode='w'):
-    with open(filename, mode) as f:
-        yaml.dump(obj, f, default_flow_style=False)
-
-def run_grid_search_experiments(
-        args,
-        script='train.py',
-        change_wt_name=True
-):
-    cfg = yaml.safe_load(open(args.cfg))
-    # parse config file to split into a list of configs with tuning hyperparameters separated
-    # Also return the names of the tuned hyperparameters
-    different_configs, hyperparams = get_grid_search_configs(
-        cfg,
-        excluded_keys=['TRAINING/DATASETS', 'TRAINING/DATASET_MIX_PDF', 'VALIDATION/DATASETS'],
-    )
-    logger.info(f'Grid search hparams: \n {hyperparams}')
-
-    # The config file may be missing some default values, so we need to add them
-    different_configs = [update_hparams_from_dict(c) for c in different_configs]
-    logger.info(f'======> Number of experiment configurations is {len(different_configs)}')
-
-    config_to_run = CN(different_configs[args.cfg_id])
-
-    if args.cluster:
-        execute_task_on_cluster(
-            script=script,
-            exp_name=config_to_run.EXP_NAME,
-            output_dir=config_to_run.OUTPUT_DIR,
-            condor_dir=config_to_run.CONDOR_DIR,
-            cfg_file=args.cfg,
-            num_exp=len(different_configs),
-            bid_amount=args.bid,
-            num_workers=config_to_run.DATASET.NUM_WORKERS,
-            memory=args.memory,
-            exp_opts=args.opts,
-            gpu_min_mem=args.gpu_min_mem,
-            gpu_arch=args.gpu_arch,
-        )
-        exit()
-
-    # ==== create logdir using hyperparam settings
-    logtime = time.strftime('%d-%m-%Y_%H-%M-%S')
-    logdir = f'{logtime}_{config_to_run.EXP_NAME}'
-    wt_file = config_to_run.EXP_NAME + '_'
-    for hp in hyperparams:
-        v = get_from_dict(different_configs[args.cfg_id], hp.split('/'))
-        logdir += f'_{hp.replace("/", ".").replace("_", "").lower()}-{v}'
-        wt_file += f'{hp.replace("/", ".").replace("_", "").lower()}-{v}_'
-    logdir = os.path.join(config_to_run.OUTPUT_DIR, logdir)
-    os.makedirs(logdir, exist_ok=True)
-    config_to_run.LOGDIR = logdir
-
-    wt_file += 'best.pth'
-    wt_path = os.path.join(os.path.dirname(config_to_run.TRAINING.BEST_MODEL_PATH), wt_file)
-    if change_wt_name: config_to_run.TRAINING.BEST_MODEL_PATH = wt_path
-
-    shutil.copy(src=args.cfg, dst=os.path.join(logdir, 'config.yaml'))
-
-    # save config
-    save_dict_to_yaml(
-        unflatten(flatten(config_to_run)),
-        os.path.join(config_to_run.LOGDIR, 'config_to_run.yaml')
-    )
-
-    return config_to_run
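The grid search treats every list-valued YAML key (outside the excluded dataset keys) as one axis of a Cartesian product; a small sketch:

    cfg = {'OPTIMIZER': {'LR': [1e-4, 5e-5]}, 'TRAINING': {'ENCODER': 'hrnet'}}
    experiments, tuned = get_grid_search_configs(cfg)
    print(len(experiments))                   # 2, one per learning rate
    print(tuned)                              # ['OPTIMIZER/LR']
    print(experiments[0]['OPTIMIZER']['LR'])  # 0.0001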
utils/default_hparams.py
DELETED
@@ -1,45 +0,0 @@
-from yacs.config import CfgNode as CN
-
-# Set default hparams to construct new default config
-# Make sure the defaults are same as in parser
-hparams = CN()
-
-# General settings
-hparams.EXP_NAME = 'default'
-hparams.PROJECT_NAME = 'default'
-hparams.OUTPUT_DIR = 'deco_results/'
-hparams.CONDOR_DIR = '/is/cluster/work/achatterjee/condor/rich/'
-hparams.LOGDIR = ''
-
-# Dataset hparams
-hparams.DATASET = CN()
-hparams.DATASET.BATCH_SIZE = 64
-hparams.DATASET.NUM_WORKERS = 4
-hparams.DATASET.NORMALIZE_IMAGES = True
-
-# Optimizer hparams
-hparams.OPTIMIZER = CN()
-hparams.OPTIMIZER.TYPE = 'adam'
-hparams.OPTIMIZER.LR = 5e-5
-hparams.OPTIMIZER.NUM_UPDATE_LR = 10
-
-# Training hparams
-hparams.TRAINING = CN()
-hparams.TRAINING.ENCODER = 'hrnet'
-hparams.TRAINING.CONTEXT = True
-hparams.TRAINING.NUM_EPOCHS = 50
-hparams.TRAINING.SUMMARY_STEPS = 100
-hparams.TRAINING.CHECKPOINT_EPOCHS = 5
-hparams.TRAINING.NUM_EARLY_STOP = 10
-hparams.TRAINING.DATASETS = ['rich']
-hparams.TRAINING.DATASET_MIX_PDF = ['1.']
-hparams.TRAINING.DATASET_ROOT_PATH = '/is/cluster/work/achatterjee/rich/npzs'
-hparams.TRAINING.BEST_MODEL_PATH = '/is/cluster/work/achatterjee/weights/rich/exp/rich_exp.pth'
-hparams.TRAINING.LOSS_WEIGHTS = 1.
-hparams.TRAINING.PAL_LOSS_WEIGHTS = 1.
-
-# Validation hparams
-hparams.VALIDATION = CN()
-hparams.VALIDATION.SUMMARY_STEPS = 100
-hparams.VALIDATION.DATASETS = ['rich']
-hparams.VALIDATION.MAIN_DATASET = 'rich'
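These defaults are meant to be overridden per experiment from a YAML file via utils.config; a sketch (the config path is illustrative):

    from utils.config import update_hparams

    hparams = update_hparams('configs/my_run.yaml')  # merges the file into a clone of the defaults
    print(hparams.OPTIMIZER.LR, hparams.TRAINING.ENCODER)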
utils/diff_renderer.py
DELETED
@@ -1,287 +0,0 @@
-# from https://gitlab.tuebingen.mpg.de/mkocabas/projects/-/blob/master/pare/pare/utils/diff_renderer.py
-
-import torch
-import numpy as np
-import torch.nn as nn
-
-from pytorch3d.renderer import (
-    PerspectiveCameras,
-    RasterizationSettings,
-    DirectionalLights,
-    BlendParams,
-    HardFlatShader,
-    MeshRasterizer,
-    TexturesVertex,
-    TexturesAtlas
-)
-from pytorch3d.structures import Meshes
-
-from .image_utils import get_default_camera
-from .smpl_uv import get_tenet_texture
-
-
-class MeshRendererWithDepth(nn.Module):
-    """
-    A class for rendering a batch of heterogeneous meshes. The class should
-    be initialized with a rasterizer and shader class which each have a forward
-    function.
-    """
-
-    def __init__(self, rasterizer, shader):
-        super().__init__()
-        self.rasterizer = rasterizer
-        self.shader = shader
-
-    def forward(self, meshes_world, **kwargs) -> torch.Tensor:
-        """
-        Render a batch of images from a batch of meshes by rasterizing and then
-        shading.
-
-        NOTE: If the blur radius for rasterization is > 0.0, some pixels can
-        have one or more barycentric coordinates lying outside the range [0, 1].
-        For a pixel with out of bounds barycentric coordinates with respect to a
-        face f, clipping is required before interpolating the texture uv
-        coordinates and z buffer so that the colors and depths are limited to
-        the range for the corresponding face.
-        """
-        fragments = self.rasterizer(meshes_world, **kwargs)
-        images = self.shader(fragments, meshes_world, **kwargs)
-
-        mask = (fragments.zbuf > -1).float()
-
-        zbuf = fragments.zbuf.view(images.shape[0], -1)
-        # print(images.shape, zbuf.shape)
-        depth = (zbuf - zbuf.min(-1, keepdims=True).values) / \
-                (zbuf.max(-1, keepdims=True).values - zbuf.min(-1, keepdims=True).values)
-        depth = depth.reshape(*images.shape[:3] + (1,))
-
-        images = torch.cat([images[:, :, :, :3], mask, depth], dim=-1)
-        return images
-
-
-class DifferentiableRenderer(nn.Module):
-    def __init__(
-            self,
-            img_h,
-            img_w,
-            focal_length,
-            device='cuda',
-            background_color=(0.0, 0.0, 0.0),
-            texture_mode='smplpix',
-            vertex_colors=None,
-            face_textures=None,
-            smpl_faces=None,
-            is_train=False,
-            is_cam_batch=False,
-    ):
-        super(DifferentiableRenderer, self).__init__()
-        self.x = 'a'
-        self.img_h = img_h
-        self.img_w = img_w
-        self.device = device
-        self.focal_length = focal_length
-        K, R = get_default_camera(focal_length, img_h, img_w, is_cam_batch=is_cam_batch)
-        K, R = K.to(device), R.to(device)
-
-        # T = torch.tensor([[0, 0, 2.5 * self.focal_length / max(self.img_h, self.img_w)]]).to(device)
-        if is_cam_batch:
-            T = torch.zeros((K.shape[0], 3)).to(device)
-        else:
-            T = torch.tensor([[0.0, 0.0, 0.0]]).to(device)
-        self.background_color = background_color
-        self.renderer = None
-        smpl_faces = smpl_faces
-
-        if texture_mode == 'smplpix':
-            face_colors = get_tenet_texture(mode=texture_mode).to(device).float()
-            vertex_colors = torch.from_numpy(
-                np.load(f'data/smpl/{texture_mode}_vertex_colors.npy')[:, :3]
-            ).unsqueeze(0).to(device).float()
-        if texture_mode == 'partseg':
-            vertex_colors = vertex_colors[..., :3].unsqueeze(0).to(device)
-            face_colors = face_textures.to(device)
-        if texture_mode == 'deco':
-            vertex_colors = vertex_colors[..., :3].to(device)
-            face_colors = face_textures.to(device)
-
-        self.register_buffer('K', K)
-        self.register_buffer('R', R)
-        self.register_buffer('T', T)
-        self.register_buffer('face_colors', face_colors)
-        self.register_buffer('vertex_colors', vertex_colors)
-        self.register_buffer('smpl_faces', smpl_faces)
-
-        self.set_requires_grad(is_train)
-
-    def set_requires_grad(self, val=False):
-        self.K.requires_grad_(val)
-        self.R.requires_grad_(val)
-        self.T.requires_grad_(val)
-        self.face_colors.requires_grad_(val)
-        self.vertex_colors.requires_grad_(val)
-        # check if smpl_faces is a FloatTensor as requires_grad_ is not defined for LongTensor
-        if isinstance(self.smpl_faces, torch.FloatTensor):
-            self.smpl_faces.requires_grad_(val)
-
-    def forward(self, vertices, faces=None, R=None, T=None):
-        raise NotImplementedError
-
-
-class Pytorch3D(DifferentiableRenderer):
-    def __init__(
-            self,
-            img_h,
-            img_w,
-            focal_length,
-            device='cuda',
-            background_color=(0.0, 0.0, 0.0),
-            texture_mode='smplpix',
-            vertex_colors=None,
-            face_textures=None,
-            smpl_faces=None,
-            model_type='smpl',
-            is_train=False,
-            is_cam_batch=False,
-    ):
-        super(Pytorch3D, self).__init__(
-            img_h,
-            img_w,
-            focal_length,
-            device=device,
-            background_color=background_color,
-            texture_mode=texture_mode,
-            vertex_colors=vertex_colors,
-            face_textures=face_textures,
-            smpl_faces=smpl_faces,
-            is_train=is_train,
-            is_cam_batch=is_cam_batch,
-        )
-
-        # this R converts the camera from pyrender NDC to
-        # OpenGL coordinate frame. It is basically R(180, X) x R(180, Y)
-        # I manually defined it here for convenience
-        self.R = self.R @ torch.tensor(
-            [[[-1.0, 0.0, 0.0],
-              [0.0, -1.0, 0.0],
-              [0.0, 0.0, 1.0]]],
-            dtype=self.R.dtype, device=self.R.device,
-        )
-
-        if is_cam_batch:
-            focal_length = self.focal_length
-        else:
-            focal_length = self.focal_length[None, :]
-
-        principal_point = ((self.img_w // 2, self.img_h // 2),)
-        image_size = ((self.img_h, self.img_w),)
-
-        cameras = PerspectiveCameras(
-            device=self.device,
-            focal_length=focal_length,
-            principal_point=principal_point,
-            R=self.R,
-            T=self.T,
-            in_ndc=False,
-            image_size=image_size,
-        )
-
-        for param in cameras.parameters():
-            param.requires_grad_(False)
-
-        raster_settings = RasterizationSettings(
-            image_size=(self.img_h, self.img_w),
-            blur_radius=0.0,
-            max_faces_per_bin=20000,
-            faces_per_pixel=1,
-        )
-
-        lights = DirectionalLights(
-            device=self.device,
-            ambient_color=((1.0, 1.0, 1.0),),
-            diffuse_color=((0.0, 0.0, 0.0),),
-            specular_color=((0.0, 0.0, 0.0),),
-            direction=((0, 1, 0),),
-        )
-
-        blend_params = BlendParams(background_color=self.background_color)
-
-        shader = HardFlatShader(device=self.device,
-                                cameras=cameras,
-                                blend_params=blend_params,
-                                lights=lights)
-
-        self.textures = TexturesVertex(verts_features=self.vertex_colors)
-
-        self.renderer = MeshRendererWithDepth(
-            rasterizer=MeshRasterizer(
-                cameras=cameras,
-                raster_settings=raster_settings
-            ),
-            shader=shader,
-        )
-
-    def forward(self, vertices, faces=None, R=None, T=None, face_atlas=None):
|
224 |
-
batch_size = vertices.shape[0]
|
225 |
-
if faces is None:
|
226 |
-
faces = self.smpl_faces.expand(batch_size, -1, -1)
|
227 |
-
|
228 |
-
if R is None:
|
229 |
-
R = self.R.expand(batch_size, -1, -1)
|
230 |
-
|
231 |
-
if T is None:
|
232 |
-
T = self.T.expand(batch_size, -1)
|
233 |
-
|
234 |
-
# convert camera translation to pytorch3d coordinate frame
|
235 |
-
T = torch.bmm(R, T.unsqueeze(-1)).squeeze(-1)
|
236 |
-
|
237 |
-
vertex_textures = TexturesVertex(
|
238 |
-
verts_features=self.vertex_colors.expand(batch_size, -1, -1)
|
239 |
-
)
|
240 |
-
|
241 |
-
# face_textures needed because vertex_texture cause interpolation at boundaries
|
242 |
-
if face_atlas:
|
243 |
-
face_textures = TexturesAtlas(atlas=face_atlas)
|
244 |
-
else:
|
245 |
-
face_textures = TexturesAtlas(atlas=self.face_colors)
|
246 |
-
|
247 |
-
# we may need to rotate the mesh
|
248 |
-
meshes = Meshes(verts=vertices, faces=faces, textures=face_textures)
|
249 |
-
images = self.renderer(meshes, R=R, T=T)
|
250 |
-
images = images.permute(0, 3, 1, 2)
|
251 |
-
return images
|
252 |
-
|
253 |
-
|
254 |
-
class NeuralMeshRenderer(DifferentiableRenderer):
|
255 |
-
def __init__(self, *args, **kwargs):
|
256 |
-
import neural_renderer as nr
|
257 |
-
|
258 |
-
super(NeuralMeshRenderer, self).__init__(*args, **kwargs)
|
259 |
-
|
260 |
-
self.neural_renderer = nr.Renderer(
|
261 |
-
dist_coeffs=None,
|
262 |
-
orig_size=self.img_size,
|
263 |
-
image_size=self.img_size,
|
264 |
-
light_intensity_ambient=1,
|
265 |
-
light_intensity_directional=0,
|
266 |
-
anti_aliasing=False,
|
267 |
-
)
|
268 |
-
|
269 |
-
def forward(self, vertices, faces=None, R=None, T=None):
|
270 |
-
batch_size = vertices.shape[0]
|
271 |
-
if faces is None:
|
272 |
-
faces = self.smpl_faces.expand(batch_size, -1, -1)
|
273 |
-
|
274 |
-
if R is None:
|
275 |
-
R = self.R.expand(batch_size, -1, -1)
|
276 |
-
|
277 |
-
if T is None:
|
278 |
-
T = self.T.expand(batch_size, -1)
|
279 |
-
rgb, depth, mask = self.neural_renderer(
|
280 |
-
vertices,
|
281 |
-
faces,
|
282 |
-
textures=self.face_colors.expand(batch_size, -1, -1, -1, -1, -1),
|
283 |
-
K=self.K.expand(batch_size, -1, -1),
|
284 |
-
R=R,
|
285 |
-
t=T.unsqueeze(1),
|
286 |
-
)
|
287 |
-
return torch.cat([rgb, depth.unsqueeze(1), mask.unsqueeze(1)], dim=1)
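
For context, a minimal sketch of how the deleted Pytorch3D renderer was presumably driven. Everything here is illustrative: the focal length, image size, and smpl_faces tensor are assumptions, and the 'smplpix' mode requires the data/smpl texture assets and a CUDA device to be available.

    import torch

    # smpl_faces: assumed to be a (1, F, 3) LongTensor of SMPL triangle indices
    renderer = Pytorch3D(
        img_h=224, img_w=224,
        focal_length=torch.tensor([5000.0]),  # 1-D so focal_length[None, :] works
        texture_mode='smplpix',
        smpl_faces=smpl_faces,
    )
    vertices = torch.rand(2, 6890, 3, device='cuda')  # placeholder SMPL vertices
    out = renderer(vertices)  # (B, 5, H, W): RGB, silhouette mask, normalized depth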
utils/get_cfg.py
DELETED
@@ -1,17 +0,0 @@
from yacs.config import CfgNode

_VALID_TYPES = {tuple, list, str, int, float, bool}


def convert_to_dict(cfg_node, key_list=[]):
    """ Convert a config node to dictionary """
    if not isinstance(cfg_node, CfgNode):
        if type(cfg_node) not in _VALID_TYPES:
            print("Key {} with value {} is not a valid type; valid types: {}".format(
                ".".join(key_list), type(cfg_node), _VALID_TYPES), )
        return cfg_node
    else:
        cfg_dict = dict(cfg_node)
        for k, v in cfg_dict.items():
            cfg_dict[k] = convert_to_dict(v, key_list + [k])
        return cfg_dict
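
As a quick illustration of what convert_to_dict produces (the config keys below are made up for the example, not taken from this repository):

    from yacs.config import CfgNode as CN

    cfg = CN()
    cfg.MODEL = CN()
    cfg.MODEL.NAME = 'hrnet_w32'   # hypothetical keys
    cfg.MODEL.NUM_JOINTS = 24

    d = convert_to_dict(cfg)
    # d == {'MODEL': {'NAME': 'hrnet_w32', 'NUM_JOINTS': 24}}: nested CfgNodes
    # are converted recursively, plain values pass through unchanged
    assert isinstance(d['MODEL'], dict)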
utils/hrnet.py
DELETED
@@ -1,625 +0,0 @@
import os

import torch
import torch.nn as nn
from loguru import logger
import torch.nn.functional as F
from yacs.config import CfgNode as CN

models = [
    'hrnet_w32',
    'hrnet_w48',
]

BN_MOMENTUM = 0.1


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
                               bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion,
                                  momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class HighResolutionModule(nn.Module):
    def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
                 num_channels, fuse_method, multi_scale_output=True):
        super(HighResolutionModule, self).__init__()
        self._check_branches(
            num_branches, blocks, num_blocks, num_inchannels, num_channels)

        self.num_inchannels = num_inchannels
        self.fuse_method = fuse_method
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        self.branches = self._make_branches(
            num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(True)

    def _check_branches(self, num_branches, blocks, num_blocks,
                        num_inchannels, num_channels):
        if num_branches != len(num_blocks):
            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
                num_branches, len(num_blocks))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
                num_branches, len(num_channels))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
                num_branches, len(num_inchannels))
            logger.error(error_msg)
            raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
                         stride=1):
        downsample = None
        if stride != 1 or \
                self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.num_inchannels[branch_index],
                    num_channels[branch_index] * block.expansion,
                    kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(
                    num_channels[branch_index] * block.expansion,
                    momentum=BN_MOMENTUM
                ),
            )

        layers = []
        layers.append(
            block(
                self.num_inchannels[branch_index],
                num_channels[branch_index],
                stride,
                downsample
            )
        )
        self.num_inchannels[branch_index] = \
            num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(
                block(
                    self.num_inchannels[branch_index],
                    num_channels[branch_index]
                )
            )

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        branches = []

        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels)
            )

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(
                        nn.Sequential(
                            nn.Conv2d(
                                num_inchannels[j],
                                num_inchannels[i],
                                1, 1, 0, bias=False
                            ),
                            nn.BatchNorm2d(num_inchannels[i]),
                            nn.Upsample(scale_factor=2**(j-i), mode='nearest')
                        )
                    )
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv3x3s = []
                    for k in range(i-j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv2d(
                                        num_inchannels[j],
                                        num_outchannels_conv3x3,
                                        3, 2, 1, bias=False
                                    ),
                                    nn.BatchNorm2d(num_outchannels_conv3x3)
                                )
                            )
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv2d(
                                        num_inchannels[j],
                                        num_outchannels_conv3x3,
                                        3, 2, 1, bias=False
                                    ),
                                    nn.BatchNorm2d(num_outchannels_conv3x3),
                                    nn.ReLU(True)
                                )
                            )
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        return self.num_inchannels

    def forward(self, x):
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []

        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

        return x_fuse


blocks_dict = {
    'BASIC': BasicBlock,
    'BOTTLENECK': Bottleneck
}


class PoseHighResolutionNet(nn.Module):

    def __init__(self, cfg):
        self.inplanes = 64
        extra = cfg['MODEL']['EXTRA']
        super(PoseHighResolutionNet, self).__init__()

        self.cfg = extra

        # stem net
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(Bottleneck, 64, 4)

        self.stage2_cfg = extra['STAGE2']
        num_channels = self.stage2_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage2_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))
        ]
        self.transition1 = self._make_transition_layer([256], num_channels)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels)

        self.stage3_cfg = extra['STAGE3']
        num_channels = self.stage3_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage3_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))
        ]
        self.transition2 = self._make_transition_layer(
            pre_stage_channels, num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels)

        self.stage4_cfg = extra['STAGE4']
        num_channels = self.stage4_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage4_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))
        ]
        self.transition3 = self._make_transition_layer(
            pre_stage_channels, num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multi_scale_output=True)

        self.final_layer = nn.Conv2d(
            in_channels=pre_stage_channels[0],
            out_channels=cfg['MODEL']['NUM_JOINTS'],
            kernel_size=extra['FINAL_CONV_KERNEL'],
            stride=1,
            padding=1 if extra['FINAL_CONV_KERNEL'] == 3 else 0
        )

        self.pretrained_layers = extra['PRETRAINED_LAYERS']

        if extra.DOWNSAMPLE and extra.USE_CONV:
            self.downsample_stage_1 = self._make_downsample_layer(3, num_channel=self.stage2_cfg['NUM_CHANNELS'][0])
            self.downsample_stage_2 = self._make_downsample_layer(2, num_channel=self.stage2_cfg['NUM_CHANNELS'][-1])
            self.downsample_stage_3 = self._make_downsample_layer(1, num_channel=self.stage3_cfg['NUM_CHANNELS'][-1])
        elif not extra.DOWNSAMPLE and extra.USE_CONV:
            self.upsample_stage_2 = self._make_upsample_layer(1, num_channel=self.stage2_cfg['NUM_CHANNELS'][-1])
            self.upsample_stage_3 = self._make_upsample_layer(2, num_channel=self.stage3_cfg['NUM_CHANNELS'][-1])
            self.upsample_stage_4 = self._make_upsample_layer(3, num_channel=self.stage4_cfg['NUM_CHANNELS'][-1])

    def _make_transition_layer(
            self, num_channels_pre_layer, num_channels_cur_layer):
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            nn.Conv2d(
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                3, 1, 1, bias=False
                            ),
                            nn.BatchNorm2d(num_channels_cur_layer[i]),
                            nn.ReLU(inplace=True)
                        )
                    )
                else:
                    transition_layers.append(None)
            else:
                conv3x3s = []
                for j in range(i+1-num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = num_channels_cur_layer[i] \
                        if j == i-num_branches_pre else inchannels
                    conv3x3s.append(
                        nn.Sequential(
                            nn.Conv2d(
                                inchannels, outchannels, 3, 2, 1, bias=False
                            ),
                            nn.BatchNorm2d(outchannels),
                            nn.ReLU(inplace=True)
                        )
                    )
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes, planes * block.expansion,
                    kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _make_stage(self, layer_config, num_inchannels,
                    multi_scale_output=True):
        num_modules = layer_config['NUM_MODULES']
        num_branches = layer_config['NUM_BRANCHES']
        num_blocks = layer_config['NUM_BLOCKS']
        num_channels = layer_config['NUM_CHANNELS']
        block = blocks_dict[layer_config['BLOCK']]
        fuse_method = layer_config['FUSE_METHOD']

        modules = []
        for i in range(num_modules):
            # multi_scale_output is only used in the last module
            if not multi_scale_output and i == num_modules - 1:
                reset_multi_scale_output = False
            else:
                reset_multi_scale_output = True

            modules.append(
                HighResolutionModule(
                    num_branches,
                    block,
                    num_blocks,
                    num_inchannels,
                    num_channels,
                    fuse_method,
                    reset_multi_scale_output
                )
            )
            num_inchannels = modules[-1].get_num_inchannels()

        return nn.Sequential(*modules), num_inchannels

    def _make_upsample_layer(self, num_layers, num_channel, kernel_size=3):
        layers = []
        for i in range(num_layers):
            layers.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
            layers.append(
                nn.Conv2d(
                    in_channels=num_channel, out_channels=num_channel,
                    kernel_size=kernel_size, stride=1, padding=1, bias=False,
                )
            )
            layers.append(nn.BatchNorm2d(num_channel, momentum=BN_MOMENTUM))
            layers.append(nn.ReLU(inplace=True))

        return nn.Sequential(*layers)

    def _make_downsample_layer(self, num_layers, num_channel, kernel_size=3):
        layers = []
        for i in range(num_layers):
            layers.append(
                nn.Conv2d(
                    in_channels=num_channel, out_channels=num_channel,
                    kernel_size=kernel_size, stride=2, padding=1, bias=False,
                )
            )
            layers.append(nn.BatchNorm2d(num_channel, momentum=BN_MOMENTUM))
            layers.append(nn.ReLU(inplace=True))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.layer1(x)

        x_list = []
        for i in range(self.stage2_cfg['NUM_BRANCHES']):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)

        x_list = []
        for i in range(self.stage3_cfg['NUM_BRANCHES']):
            if self.transition2[i] is not None:
                x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        x_list = []
        for i in range(self.stage4_cfg['NUM_BRANCHES']):
            if self.transition3[i] is not None:
                x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        x = self.stage4(x_list)

        if self.cfg.DOWNSAMPLE:
            if self.cfg.USE_CONV:
                # Downsampling with strided convolutions
                x1 = self.downsample_stage_1(x[0])
                x2 = self.downsample_stage_2(x[1])
                x3 = self.downsample_stage_3(x[2])
                x = torch.cat([x1, x2, x3, x[3]], 1)
            else:
                # Downsampling with interpolation
                x0_h, x0_w = x[3].size(2), x[3].size(3)
                x1 = F.interpolate(x[0], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
                x2 = F.interpolate(x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
                x3 = F.interpolate(x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
                x = torch.cat([x1, x2, x3, x[3]], 1)
        else:
            if self.cfg.USE_CONV:
                # Upsampling with interpolations + convolutions
                x1 = self.upsample_stage_2(x[1])
                x2 = self.upsample_stage_3(x[2])
                x3 = self.upsample_stage_4(x[3])
                x = torch.cat([x[0], x1, x2, x3], 1)
            else:
                # Upsampling with interpolation
                x0_h, x0_w = x[0].size(2), x[0].size(3)
                x1 = F.interpolate(x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
                x2 = F.interpolate(x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
                x3 = F.interpolate(x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
                x = torch.cat([x[0], x1, x2, x3], 1)

        return x

    def init_weights(self, pretrained=''):
        logger.info('=> init weights from normal distribution')
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.normal_(m.weight, std=0.001)
                for name, _ in m.named_parameters():
                    if name in ['bias']:
                        nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight, std=0.001)
                for name, _ in m.named_parameters():
                    if name in ['bias']:
                        nn.init.constant_(m.bias, 0)

        if os.path.isfile(pretrained):
            pretrained_state_dict = torch.load(pretrained)
            logger.info('=> loading pretrained model {}'.format(pretrained))

            need_init_state_dict = {}
            for name, m in pretrained_state_dict.items():
                if name.split('.')[0] in self.pretrained_layers \
                        or self.pretrained_layers[0] == '*':
                    need_init_state_dict[name] = m
            self.load_state_dict(need_init_state_dict, strict=False)
        elif pretrained:
            logger.warning('IMPORTANT WARNING!! Please download pre-trained models if you are in TRAINING mode!')
            # raise ValueError('{} does not exist!'.format(pretrained))


def get_pose_net(cfg, is_train):
    model = PoseHighResolutionNet(cfg)

    if is_train and cfg['MODEL']['INIT_WEIGHTS']:
        model.init_weights(cfg['MODEL']['PRETRAINED'])

    return model


def get_cfg_defaults(pretrained, width=32, downsample=False, use_conv=False):
    # pose_multi_resolution_net related params
    HRNET = CN()
    HRNET.PRETRAINED_LAYERS = [
        'conv1', 'bn1', 'conv2', 'bn2', 'layer1', 'transition1',
        'stage2', 'transition2', 'stage3', 'transition3', 'stage4',
    ]
    HRNET.STEM_INPLANES = 64
    HRNET.FINAL_CONV_KERNEL = 1
    HRNET.STAGE2 = CN()
    HRNET.STAGE2.NUM_MODULES = 1
    HRNET.STAGE2.NUM_BRANCHES = 2
    HRNET.STAGE2.NUM_BLOCKS = [4, 4]
    HRNET.STAGE2.NUM_CHANNELS = [width, width*2]
    HRNET.STAGE2.BLOCK = 'BASIC'
    HRNET.STAGE2.FUSE_METHOD = 'SUM'
    HRNET.STAGE3 = CN()
    HRNET.STAGE3.NUM_MODULES = 4
    HRNET.STAGE3.NUM_BRANCHES = 3
    HRNET.STAGE3.NUM_BLOCKS = [4, 4, 4]
    HRNET.STAGE3.NUM_CHANNELS = [width, width*2, width*4]
    HRNET.STAGE3.BLOCK = 'BASIC'
    HRNET.STAGE3.FUSE_METHOD = 'SUM'
    HRNET.STAGE4 = CN()
    HRNET.STAGE4.NUM_MODULES = 3
    HRNET.STAGE4.NUM_BRANCHES = 4
    HRNET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
    HRNET.STAGE4.NUM_CHANNELS = [width, width*2, width*4, width*8]
    HRNET.STAGE4.BLOCK = 'BASIC'
    HRNET.STAGE4.FUSE_METHOD = 'SUM'
    HRNET.DOWNSAMPLE = downsample
    HRNET.USE_CONV = use_conv

    cfg = CN()
    cfg.MODEL = CN()
    cfg.MODEL.INIT_WEIGHTS = True
    cfg.MODEL.PRETRAINED = pretrained  # 'data/pretrained_models/hrnet_w32-36af842e.pth'
    cfg.MODEL.EXTRA = HRNET
    cfg.MODEL.NUM_JOINTS = 24
    return cfg


def hrnet_w32(
        pretrained=True,
        pretrained_ckpt='data/weights/pose_hrnet_w32_256x192.pth',
        downsample=False,
        use_conv=False,
):
    cfg = get_cfg_defaults(pretrained_ckpt, width=32, downsample=downsample, use_conv=use_conv)
    return get_pose_net(cfg, is_train=True)


def hrnet_w48(
        pretrained=True,
        pretrained_ckpt='data/weights/pose_hrnet_w48_256x192.pth',
        downsample=False,
        use_conv=False,
):
    cfg = get_cfg_defaults(pretrained_ckpt, width=48, downsample=downsample, use_conv=use_conv)
    return get_pose_net(cfg, is_train=True)
utils/image_utils.py
DELETED
@@ -1,444 +0,0 @@
"""
This file contains functions that are used to perform data augmentation.
"""
import cv2
import torch
import json
from skimage.transform import rotate, resize
import numpy as np
import jpeg4py as jpeg
from trimesh.visual import color

# from ..core import constants
# from .vibe_image_utils import gen_trans_from_patch_cv
from .kp_utils import map_smpl_to_common, get_smpl_joint_names


def get_transform(center, scale, res, rot=0):
    """Generate transformation matrix."""
    h = 200 * scale
    t = np.zeros((3, 3))
    t[0, 0] = float(res[1]) / h
    t[1, 1] = float(res[0]) / h
    t[0, 2] = res[1] * (-float(center[0]) / h + .5)
    t[1, 2] = res[0] * (-float(center[1]) / h + .5)
    t[2, 2] = 1
    if not rot == 0:
        rot = -rot  # To match direction of rotation from cropping
        rot_mat = np.zeros((3, 3))
        rot_rad = rot * np.pi / 180
        sn, cs = np.sin(rot_rad), np.cos(rot_rad)
        rot_mat[0, :2] = [cs, -sn]
        rot_mat[1, :2] = [sn, cs]
        rot_mat[2, 2] = 1
        # Need to rotate around center
        t_mat = np.eye(3)
        t_mat[0, 2] = -res[1] / 2
        t_mat[1, 2] = -res[0] / 2
        t_inv = t_mat.copy()
        t_inv[:2, 2] *= -1
        t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
    return t


def transform(pt, center, scale, res, invert=0, rot=0):
    """Transform pixel location to a different reference frame."""
    t = get_transform(center, scale, res, rot=rot)
    if invert:
        t = np.linalg.inv(t)
    new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
    new_pt = np.dot(t, new_pt)
    return new_pt[:2].astype(int) + 1


def crop(img, center, scale, res, rot=0):
    """Crop image according to the supplied bounding box."""
    # Upper left point
    ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1
    # Bottom right point
    br = np.array(transform([res[0] + 1,
                             res[1] + 1], center, scale, res, invert=1)) - 1

    # Padding so that when rotated the proper amount of context is included
    pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
    if not rot == 0:
        ul -= pad
        br += pad

    new_shape = [br[1] - ul[1], br[0] - ul[0]]
    if len(img.shape) > 2:
        new_shape += [img.shape[2]]
    new_img = np.zeros(new_shape)

    # Range to fill new array
    new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
    new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
    # Range to sample from original image
    old_x = max(0, ul[0]), min(len(img[0]), br[0])
    old_y = max(0, ul[1]), min(len(img), br[1])
    new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
                                                        old_x[0]:old_x[1]]

    if not rot == 0:
        # Rotate, then remove padding
        new_img = rotate(new_img, rot)  # scipy.misc.imrotate(new_img, rot)
        new_img = new_img[pad:-pad, pad:-pad]

    # resize image
    new_img = resize(new_img, res)  # scipy.misc.imresize(new_img, res)
    return new_img


def crop_cv2(img, center, scale, res, rot=0):
    c_x, c_y = center
    c_x, c_y = int(round(c_x)), int(round(c_y))
    patch_width, patch_height = int(round(res[0])), int(round(res[1]))
    bb_width = bb_height = int(round(scale * 200.))

    # NOTE: gen_trans_from_patch_cv comes from the .vibe_image_utils import
    # that is commented out above; this function needs that import restored.
    trans = gen_trans_from_patch_cv(
        c_x, c_y, bb_width, bb_height,
        patch_width, patch_height,
        scale=1.0, rot=rot, inv=False,
    )

    crop_img = cv2.warpAffine(
        img, trans, (int(patch_width), int(patch_height)),
        flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT
    )

    return crop_img


def get_random_crop_coords(height, width, crop_height, crop_width, h_start, w_start):
    y1 = int((height - crop_height) * h_start)
    y2 = y1 + crop_height
    x1 = int((width - crop_width) * w_start)
    x2 = x1 + crop_width
    return x1, y1, x2, y2


def random_crop(center, scale, crop_scale_factor, axis='all'):
    '''
    center: bbox center [x,y]
    scale: bbox height / 200
    crop_scale_factor: amount of cropping to be applied
    axis: axis along which cropping will be applied
        "x": center the y axis and get random crops in x
        "y": center the x axis and get random crops in y
        "all": randomly crop from all locations
    '''
    orig_size = int(scale * 200.)
    ul = (center - (orig_size / 2.)).astype(int)

    crop_size = int(orig_size * crop_scale_factor)

    if axis == 'all':
        h_start = np.random.rand()
        w_start = np.random.rand()
    elif axis == 'x':
        h_start = np.random.rand()
        w_start = 0.5
    elif axis == 'y':
        h_start = 0.5
        w_start = np.random.rand()
    else:
        raise ValueError(f'axis {axis} is undefined!')

    x1, y1, x2, y2 = get_random_crop_coords(
        height=orig_size,
        width=orig_size,
        crop_height=crop_size,
        crop_width=crop_size,
        h_start=h_start,
        w_start=w_start,
    )
    scale = (y2 - y1) / 200.
    center = ul + np.array([(y1 + y2) / 2, (x1 + x2) / 2])
    return center, scale


def uncrop(img, center, scale, orig_shape, rot=0, is_rgb=True):
    """'Undo' the image cropping/resizing.
    This function is used when evaluating mask/part segmentation.
    """
    res = img.shape[:2]
    # Upper left point
    ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1
    # Bottom right point
    br = np.array(transform([res[0] + 1, res[1] + 1], center, scale, res, invert=1)) - 1
    # size of cropped image
    crop_shape = [br[1] - ul[1], br[0] - ul[0]]

    new_shape = [br[1] - ul[1], br[0] - ul[0]]
    if len(img.shape) > 2:
        new_shape += [img.shape[2]]
    new_img = np.zeros(orig_shape, dtype=np.uint8)
    # Range to fill new array
    new_x = max(0, -ul[0]), min(br[0], orig_shape[1]) - ul[0]
    new_y = max(0, -ul[1]), min(br[1], orig_shape[0]) - ul[1]
    # Range to sample from original image
    old_x = max(0, ul[0]), min(orig_shape[1], br[0])
    old_y = max(0, ul[1]), min(orig_shape[0], br[1])
    img = resize(img, crop_shape)  # scipy.misc.imresize(img, crop_shape, interp='nearest')
    new_img[old_y[0]:old_y[1], old_x[0]:old_x[1]] = img[new_y[0]:new_y[1], new_x[0]:new_x[1]]
    return new_img


def rot_aa(aa, rot):
    """Rotate axis-angle parameters."""
    # pose parameters
    R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
                  [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
                  [0, 0, 1]])
    # find the rotation of the body in camera frame
    per_rdg, _ = cv2.Rodrigues(aa)
    # apply the global rotation to the global orientation
    resrot, _ = cv2.Rodrigues(np.dot(R, per_rdg))
    aa = (resrot.T)[0]
    return aa


def flip_img(img):
    """Flip rgb images or masks.
    Channels come last, e.g. (256, 256, 3).
    """
    img = np.fliplr(img)
    return img


def flip_kp(kp):
    """Flip keypoints."""
    # NOTE: relies on the constants module whose import is commented out above
    if len(kp) == 24:
        flipped_parts = constants.J24_FLIP_PERM
    elif len(kp) == 49:
        flipped_parts = constants.J49_FLIP_PERM
    kp = kp[flipped_parts]
    kp[:, 0] = - kp[:, 0]
    return kp


def flip_pose(pose):
    """Flip pose.
    The flipping is based on SMPL parameters.
    """
    flipped_parts = constants.SMPL_POSE_FLIP_PERM
    pose = pose[flipped_parts]
    # we also negate the second and the third dimension of the axis-angle
    pose[1::3] = -pose[1::3]
    pose[2::3] = -pose[2::3]
    return pose


def denormalize_images(images):
    images = images * torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1, 3, 1, 1)
    images = images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1, 3, 1, 1)
    return images


def read_img(img_fn):
    # return pil_img.fromarray(
    #     cv2.cvtColor(cv2.imread(img_fn), cv2.COLOR_BGR2RGB))
    # with open(img_fn, 'rb') as f:
    #     img = pil_img.open(f).convert('RGB')
    # return img
    if img_fn.endswith('jpeg') or img_fn.endswith('jpg'):
        try:
            with open(img_fn, 'rb') as f:
                img = np.array(jpeg.JPEG(f).decode())
        except jpeg.JPEGRuntimeError:
            # logger.warning('{} produced a JPEGRuntimeError', img_fn)
            img = cv2.cvtColor(cv2.imread(img_fn), cv2.COLOR_BGR2RGB)
    else:
        # elif img_fn.endswith('png') or img_fn.endswith('JPG') or img_fn.endswith(''):
        img = cv2.cvtColor(cv2.imread(img_fn), cv2.COLOR_BGR2RGB)
    return img.astype(np.float32)


def generate_heatmaps_2d(joints, joints_vis, num_joints=24, heatmap_size=56, image_size=224, sigma=1.75):
    '''
    :param joints: [num_joints, 3]
    :param joints_vis: [num_joints, 3]
    :return: target, target_weight (1: visible, 0: invisible)
    '''
    target_weight = np.ones((num_joints, 1), dtype=np.float32)
    target_weight[:, 0] = joints_vis[:, 0]

    target = np.zeros((num_joints, heatmap_size, heatmap_size), dtype=np.float32)

    tmp_size = sigma * 3

    # denormalize joints into heatmap coordinates
    joints = (joints + 1.) * (image_size / 2.)

    for joint_id in range(num_joints):
        feat_stride = image_size / heatmap_size
        mu_x = int(joints[joint_id][0] / feat_stride + 0.5)
        mu_y = int(joints[joint_id][1] / feat_stride + 0.5)
        # Check that any part of the gaussian is in-bounds
        ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
        br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
        if ul[0] >= heatmap_size or ul[1] >= heatmap_size \
                or br[0] < 0 or br[1] < 0:
            # If not, mark the joint invisible and skip it
            target_weight[joint_id] = 0
            continue

        # Generate gaussian
        size = 2 * tmp_size + 1
        x = np.arange(0, size, 1, np.float32)
        y = x[:, np.newaxis]
        x0 = y0 = size // 2
        # The gaussian is not normalized; we want the center value to equal 1
        g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))

        # Usable gaussian range
        g_x = max(0, -ul[0]), min(br[0], heatmap_size) - ul[0]
        g_y = max(0, -ul[1]), min(br[1], heatmap_size) - ul[1]
        # Image range
        img_x = max(0, ul[0]), min(br[0], heatmap_size)
        img_y = max(0, ul[1]), min(br[1], heatmap_size)

        v = target_weight[joint_id]
        if v > 0.5:
            target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
                g[g_y[0]:g_y[1], g_x[0]:g_x[1]]

    return target, target_weight


def generate_part_labels(vertices, faces, cam_t, neural_renderer, body_part_texture, K, R, part_bins):
    batch_size = vertices.shape[0]

    body_parts, depth, mask = neural_renderer(
        vertices,
        faces.expand(batch_size, -1, -1),
        textures=body_part_texture.expand(batch_size, -1, -1, -1, -1, -1),
        K=K.expand(batch_size, -1, -1),
        R=R.expand(batch_size, -1, -1),
        t=cam_t.unsqueeze(1),
    )

    render_rgb = body_parts.clone()

    body_parts = body_parts.permute(0, 2, 3, 1)
    body_parts *= 255.  # multiply by 255 to spread the labels apart
    body_parts, _ = body_parts.max(-1)  # reduce to a single channel

    body_parts = torch.bucketize(body_parts.detach(), part_bins, right=True)  # np.digitize(body_parts, bins, right=True)

    # add 1 to make the background label 0
    body_parts = body_parts.long() + 1
    body_parts = body_parts * mask.detach()

    return body_parts.long(), render_rgb


def generate_heatmaps_2d_batch(joints, num_joints=24, heatmap_size=56, image_size=224, sigma=1.75):
    batch_size = joints.shape[0]

    joints = joints.detach().cpu().numpy()
    joints_vis = np.ones_like(joints)

    heatmaps = []
    heatmaps_vis = []
    for i in range(batch_size):
        hm, hm_vis = generate_heatmaps_2d(joints[i], joints_vis[i], num_joints, heatmap_size, image_size, sigma)
        heatmaps.append(hm)
        heatmaps_vis.append(hm_vis)

    return torch.from_numpy(np.stack(heatmaps)).float().to('cuda'), \
        torch.from_numpy(np.stack(heatmaps_vis)).float().to('cuda')


def get_body_part_texture(faces, model_type='smpl', non_parametric=False):
    if model_type == 'smpl':
        n_vertices = 6890
        segmentation_path = 'data/smpl_vert_segmentation.json'
    if model_type == 'smplx':
        n_vertices = 10475
        segmentation_path = 'data/smplx_vert_segmentation.json'

    with open(segmentation_path, 'rb') as f:
        part_segmentation = json.load(f)

    # map all vertex ids to the joint ids
    joint_names = get_smpl_joint_names()
    smplx_extra_joint_names = ['leftEye', 'eyeballs', 'rightEye']
    body_vert_idx = np.zeros((n_vertices), dtype=np.int32) - 1  # -1 for missing label
    for i, (k, v) in enumerate(part_segmentation.items()):
        if k in smplx_extra_joint_names and model_type == 'smplx':
            k = 'head'  # map all extra smplx face joints to head
        body_joint_idx = joint_names.index(k)
        body_vert_idx[v] = body_joint_idx

    # pare implementation
    # import joblib
    # part_segmentation = joblib.load('data/smpl_partSegmentation_mapping.pkl')
    # body_vert_idx = part_segmentation['smpl_index']

    n_parts = 24.

    if non_parametric:
        # reduce the number of body parts to 14
        # by mapping some joints to others
        n_parts = 14.
        joint_mapping = map_smpl_to_common()

        for jm in joint_mapping:
            for j in jm[0]:
                body_vert_idx[body_vert_idx == j] = jm[1]

    vertex_colors = np.ones((n_vertices, 4))
    vertex_colors[:, :3] = body_vert_idx[..., None]

    vertex_colors = color.to_rgba(vertex_colors)
    vertex_colors = vertex_colors[:, :3] / 255.

    face_colors = vertex_colors[faces].min(axis=1)
    texture = np.zeros((1, faces.shape[0], 1, 1, 3), dtype=np.float32)
    # texture[0, :, 0, 0, :] = face_colors[:, :3] / n_parts
    texture[0, :, 0, 0, :] = face_colors[:, :3]

    vertex_colors = torch.from_numpy(vertex_colors).float()
    texture = torch.from_numpy(texture).float()
    return vertex_colors, texture


def get_default_camera(focal_length, img_h, img_w, is_cam_batch=False):
    if not is_cam_batch:
        K = torch.eye(3)
        K[0, 0] = focal_length
        K[1, 1] = focal_length
        K[2, 2] = 1
        K[0, 2] = img_w / 2.
        K[1, 2] = img_h / 2.
        K = K[None, :, :]
        R = torch.eye(3)[None, :, :]
    else:
        bs = focal_length.shape[0]
        K = torch.eye(3)[None, :, :].repeat(bs, 1, 1)
        K[:, 0, 0] = focal_length[:, 0]
        K[:, 1, 1] = focal_length[:, 1]
        K[:, 2, 2] = 1
        K[:, 0, 2] = img_w / 2.
        K[:, 1, 2] = img_h / 2.
        R = torch.eye(3)[None, :, :].repeat(bs, 1, 1)
    return K, R


def read_exif_data(img_fname):
    import PIL.Image
    import PIL.ExifTags

    img = PIL.Image.open(img_fname)
    exif_data = img._getexif()

    if exif_data is None:
        return None

    exif = {
        PIL.ExifTags.TAGS[k]: v
        for k, v in exif_data.items()
        if k in PIL.ExifTags.TAGS
    }
    return exif
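
To make the heatmap contract concrete, a small sketch of generate_heatmaps_2d with placeholder joints (joint coordinates are assumed normalized to [-1, 1], which the function denormalizes internally):

    import numpy as np

    joints = np.random.uniform(-0.5, 0.5, size=(24, 3))  # placeholder joints
    joints_vis = np.ones((24, 3), dtype=np.float32)       # all marked visible

    target, target_weight = generate_heatmaps_2d(joints, joints_vis)
    print(target.shape)         # (24, 56, 56): one unnormalized Gaussian per joint
    print(target_weight.shape)  # (24, 1): zeroed where the Gaussian falls out of bounds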
utils/kp_utils.py
DELETED
@@ -1,1114 +0,0 @@
import numpy as np


def keypoint_hflip(kp, img_width):
    # Flip a keypoint horizontally around the y-axis
    # kp N,2
    if len(kp.shape) == 2:
        kp[:, 0] = (img_width - 1.) - kp[:, 0]
    elif len(kp.shape) == 3:
        kp[:, :, 0] = (img_width - 1.) - kp[:, :, 0]
    return kp


def convert_kps(joints2d, src, dst):
    src_names = eval(f'get_{src}_joint_names')()
    dst_names = eval(f'get_{dst}_joint_names')()

    out_joints2d = np.zeros((joints2d.shape[0], len(dst_names), joints2d.shape[-1]))

    for idx, jn in enumerate(dst_names):
        if jn in src_names:
            out_joints2d[:, idx] = joints2d[:, src_names.index(jn)]

    return out_joints2d


def get_perm_idxs(src, dst):
    src_names = eval(f'get_{src}_joint_names')()
    dst_names = eval(f'get_{dst}_joint_names')()
    idxs = [src_names.index(h) for h in dst_names if h in src_names]
    return idxs


def get_mpii3d_test_joint_names():
    return [
        'headtop',       # 'head_top',
        'neck',
        'rshoulder',     # 'right_shoulder',
        'relbow',        # 'right_elbow',
        'rwrist',        # 'right_wrist',
        'lshoulder',     # 'left_shoulder',
        'lelbow',        # 'left_elbow',
        'lwrist',        # 'left_wrist',
        'rhip',          # 'right_hip',
        'rknee',         # 'right_knee',
        'rankle',        # 'right_ankle',
        'lhip',          # 'left_hip',
        'lknee',         # 'left_knee',
        'lankle',        # 'left_ankle'
        'hip',           # 'pelvis',
        'Spine (H36M)',  # 'spine',
        'Head (H36M)',   # 'head'
    ]


def get_mpii3d_joint_names():
    return [
        'spine3',          # 0,
        'spine4',          # 1,
        'spine2',          # 2,
        'Spine (H36M)',    # 'spine', # 3,
        'hip',             # 'pelvis', # 4,
        'neck',            # 5,
        'Head (H36M)',     # 'head', # 6,
        "headtop",         # 'head_top', # 7,
        'left_clavicle',   # 8,
        "lshoulder",       # 'left_shoulder', # 9,
        "lelbow",          # 'left_elbow', # 10,
        "lwrist",          # 'left_wrist', # 11,
        'left_hand',       # 12,
        'right_clavicle',  # 13,
        'rshoulder',       # 'right_shoulder', # 14,
        'relbow',          # 'right_elbow', # 15,
        'rwrist',          # 'right_wrist', # 16,
        'right_hand',      # 17,
        'lhip',            # 'left_hip', # 18,
        'lknee',           # 'left_knee', # 19,
        'lankle',          # 'left_ankle', # 20
        'left_foot',       # 21
        'left_toe',        # 22
        "rhip",            # 'right_hip', # 23
        "rknee",           # 'right_knee', # 24
        "rankle",          # 'right_ankle', # 25
        'right_foot',      # 26
        'right_toe'        # 27
    ]


# def get_insta_joint_names():
#     return [
#         'rheel' , # 0
#         'rknee' , # 1
#         'rhip' , # 2
#         'lhip' , # 3
#         'lknee' , # 4
#         'lheel' , # 5
#         'rwrist' , # 6
#         'relbow' , # 7
#         'rshoulder' , # 8
#         'lshoulder' , # 9
#         'lelbow' , # 10
#         'lwrist' , # 11
#         'neck' , # 12
#         'headtop' , # 13
#         'nose' , # 14
#         'leye' , # 15
#         'reye' , # 16
#         'lear' , # 17
#         'rear' , # 18
#         'lbigtoe' , # 19
#         'rbigtoe' , # 20
#         'lsmalltoe' , # 21
#         'rsmalltoe' , # 22
#         'lankle' , # 23
#         'rankle' , # 24
#     ]


def get_insta_joint_names():
    return [
        'OP RHeel',
        'OP RKnee',
        'OP RHip',
        'OP LHip',
        'OP LKnee',
        'OP LHeel',
        'OP RWrist',
        'OP RElbow',
        'OP RShoulder',
        'OP LShoulder',
        'OP LElbow',
        'OP LWrist',
        'OP Neck',
        'headtop',
        'OP Nose',
        'OP LEye',
        'OP REye',
        'OP LEar',
        'OP REar',
        'OP LBigToe',
        'OP RBigToe',
        'OP LSmallToe',
        'OP RSmallToe',
        'OP LAnkle',
        'OP RAnkle',
    ]


def get_mmpose_joint_names():
    # this naming is for the first 23 joints of MMPose
    # does not include hands and face
    return [
        'OP Nose',       # 1
        'OP LEye',       # 2
        'OP REye',       # 3
        'OP LEar',       # 4
        'OP REar',       # 5
        'OP LShoulder',  # 6
        'OP RShoulder',  # 7
        'OP LElbow',     # 8
        'OP RElbow',     # 9
        'OP LWrist',     # 10
        'OP RWrist',     # 11
        'OP LHip',       # 12
        'OP RHip',       # 13
        'OP LKnee',      # 14
        'OP RKnee',      # 15
        'OP LAnkle',     # 16
        'OP RAnkle',     # 17
        'OP LBigToe',    # 18
        'OP LSmallToe',  # 19
        'OP LHeel',      # 20
        'OP RBigToe',    # 21
        'OP RSmallToe',  # 22
        'OP RHeel',      # 23
    ]


def get_insta_skeleton():
    return np.array(
        [
            [0, 1],
            [1, 2],
            [2, 3],
            [3, 4],
            [4, 5],
            [6, 7],
            [7, 8],
            [8, 9],
            [9, 10],
            [2, 8],
            [3, 9],
            [10, 11],
            [8, 12],
            [9, 12],
            [12, 13],
            [12, 14],
            [14, 15],
            [14, 16],
            [15, 17],
            [16, 18],
            [0, 20],
            [20, 22],
            [5, 19],
            [19, 21],
            [5, 23],
            [0, 24],
        ])


def get_staf_skeleton():
    return np.array(
        [
            [0, 1],
            [1, 2],
            [2, 3],
            [3, 4],
            [1, 5],
            [5, 6],
            [6, 7],
            [1, 8],
            [8, 9],
            [9, 10],
            [10, 11],
            [8, 12],
            [12, 13],
            [13, 14],
            [0, 15],
            [0, 16],
            [15, 17],
            [16, 18],
            [2, 9],
            [5, 12],
            [1, 19],
            [20, 19],
        ]
    )


def get_staf_joint_names():
    return [
        'OP Nose',              # 0,
        'OP Neck',              # 1,
        'OP RShoulder',         # 2,
        'OP RElbow',            # 3,
        'OP RWrist',            # 4,
        'OP LShoulder',         # 5,
        'OP LElbow',            # 6,
        'OP LWrist',            # 7,
        'OP MidHip',            # 8,
        'OP RHip',              # 9,
        'OP RKnee',             # 10,
        'OP RAnkle',            # 11,
        'OP LHip',              # 12,
        'OP LKnee',             # 13,
        'OP LAnkle',            # 14,
        'OP REye',              # 15,
        'OP LEye',              # 16,
        'OP REar',              # 17,
        'OP LEar',              # 18,
        'Neck (LSP)',           # 19,
        'Top of Head (LSP)',    # 20,
    ]


def get_spin_op_joint_names():
    return [
        'OP Nose',       # 0
        'OP Neck',       # 1
        'OP RShoulder',  # 2
        'OP RElbow',     # 3
        'OP RWrist',     # 4
        'OP LShoulder',  # 5
        'OP LElbow',     # 6
        'OP LWrist',     # 7
        'OP MidHip',     # 8
        'OP RHip',       # 9
        'OP RKnee',      # 10
        'OP RAnkle',     # 11
        'OP LHip',       # 12
        'OP LKnee',      # 13
        'OP LAnkle',     # 14
        'OP REye',       # 15
        'OP LEye',       # 16
        'OP REar',       # 17
        'OP LEar',       # 18
        'OP LBigToe',    # 19
        'OP LSmallToe',  # 20
        'OP LHeel',      # 21
        'OP RBigToe',    # 22
        'OP RSmallToe',  # 23
        'OP RHeel',      # 24
    ]


def get_openpose_joint_names():
    return [
        'OP Nose',       # 0
        'OP Neck',       # 1
        'OP RShoulder',  # 2
        'OP RElbow',     # 3
        'OP RWrist',     # 4
        'OP LShoulder',  # 5
        'OP LElbow',     # 6
        'OP LWrist',     # 7
        'OP MidHip',     # 8
        'OP RHip',       # 9
        'OP RKnee',      # 10
        'OP RAnkle',     # 11
        'OP LHip',       # 12
'OP LHip', # 12
|
311 |
-
'OP LKnee', # 13
|
312 |
-
'OP LAnkle', # 14
|
313 |
-
'OP REye', # 15
|
314 |
-
'OP LEye', # 16
|
315 |
-
'OP REar', # 17
|
316 |
-
'OP LEar', # 18
|
317 |
-
'OP LBigToe', # 19
|
318 |
-
'OP LSmallToe', # 20
|
319 |
-
'OP LHeel', # 21
|
320 |
-
'OP RBigToe', # 22
|
321 |
-
'OP RSmallToe', # 23
|
322 |
-
'OP RHeel', # 24
|
323 |
-
]
|
324 |
-
|
325 |
-
|
326 |
-
def get_spin_joint_names():
|
327 |
-
return [
|
328 |
-
'OP Nose', # 0
|
329 |
-
'OP Neck', # 1
|
330 |
-
'OP RShoulder', # 2
|
331 |
-
'OP RElbow', # 3
|
332 |
-
'OP RWrist', # 4
|
333 |
-
'OP LShoulder', # 5
|
334 |
-
'OP LElbow', # 6
|
335 |
-
'OP LWrist', # 7
|
336 |
-
'OP MidHip', # 8
|
337 |
-
'OP RHip', # 9
|
338 |
-
'OP RKnee', # 10
|
339 |
-
'OP RAnkle', # 11
|
340 |
-
'OP LHip', # 12
|
341 |
-
'OP LKnee', # 13
|
342 |
-
'OP LAnkle', # 14
|
343 |
-
'OP REye', # 15
|
344 |
-
'OP LEye', # 16
|
345 |
-
'OP REar', # 17
|
346 |
-
'OP LEar', # 18
|
347 |
-
'OP LBigToe', # 19
|
348 |
-
'OP LSmallToe', # 20
|
349 |
-
'OP LHeel', # 21
|
350 |
-
'OP RBigToe', # 22
|
351 |
-
'OP RSmallToe', # 23
|
352 |
-
'OP RHeel', # 24
|
353 |
-
'rankle', # 25
|
354 |
-
'rknee', # 26
|
355 |
-
'rhip', # 27
|
356 |
-
'lhip', # 28
|
357 |
-
'lknee', # 29
|
358 |
-
'lankle', # 30
|
359 |
-
'rwrist', # 31
|
360 |
-
'relbow', # 32
|
361 |
-
'rshoulder', # 33
|
362 |
-
'lshoulder', # 34
|
363 |
-
'lelbow', # 35
|
364 |
-
'lwrist', # 36
|
365 |
-
'neck', # 37
|
366 |
-
'headtop', # 38
|
367 |
-
'hip', # 39 'Pelvis (MPII)', # 39
|
368 |
-
'thorax', # 40 'Thorax (MPII)', # 40
|
369 |
-
'Spine (H36M)', # 41
|
370 |
-
'Jaw (H36M)', # 42
|
371 |
-
'Head (H36M)', # 43
|
372 |
-
'nose', # 44
|
373 |
-
'leye', # 45 'Left Eye', # 45
|
374 |
-
'reye', # 46 'Right Eye', # 46
|
375 |
-
'lear', # 47 'Left Ear', # 47
|
376 |
-
'rear', # 48 'Right Ear', # 48
|
377 |
-
]
|
378 |
-
|
379 |
-
def get_muco3dhp_joint_names():
|
380 |
-
return [
|
381 |
-
'headtop',
|
382 |
-
'thorax',
|
383 |
-
'rshoulder',
|
384 |
-
'relbow',
|
385 |
-
'rwrist',
|
386 |
-
'lshoulder',
|
387 |
-
'lelbow',
|
388 |
-
'lwrist',
|
389 |
-
'rhip',
|
390 |
-
'rknee',
|
391 |
-
'rankle',
|
392 |
-
'lhip',
|
393 |
-
'lknee',
|
394 |
-
'lankle',
|
395 |
-
'hip',
|
396 |
-
'Spine (H36M)',
|
397 |
-
'Head (H36M)',
|
398 |
-
'R_Hand',
|
399 |
-
'L_Hand',
|
400 |
-
'R_Toe',
|
401 |
-
'L_Toe'
|
402 |
-
]
|
403 |
-
|
404 |
-
def get_h36m_joint_names():
|
405 |
-
return [
|
406 |
-
'hip', # 0
|
407 |
-
'lhip', # 1
|
408 |
-
'lknee', # 2
|
409 |
-
'lankle', # 3
|
410 |
-
'rhip', # 4
|
411 |
-
'rknee', # 5
|
412 |
-
'rankle', # 6
|
413 |
-
'Spine (H36M)', # 7
|
414 |
-
'neck', # 8
|
415 |
-
'Head (H36M)', # 9
|
416 |
-
'headtop', # 10
|
417 |
-
'lshoulder', # 11
|
418 |
-
'lelbow', # 12
|
419 |
-
'lwrist', # 13
|
420 |
-
'rshoulder', # 14
|
421 |
-
'relbow', # 15
|
422 |
-
'rwrist', # 16
|
423 |
-
]
|
424 |
-
|
425 |
-
|
426 |
-
def get_spin_skeleton():
|
427 |
-
return np.array(
|
428 |
-
[
|
429 |
-
[0 , 1],
|
430 |
-
[1 , 2],
|
431 |
-
[2 , 3],
|
432 |
-
[3 , 4],
|
433 |
-
[1 , 5],
|
434 |
-
[5 , 6],
|
435 |
-
[6 , 7],
|
436 |
-
[1 , 8],
|
437 |
-
[8 , 9],
|
438 |
-
[9 ,10],
|
439 |
-
[10,11],
|
440 |
-
[8 ,12],
|
441 |
-
[12,13],
|
442 |
-
[13,14],
|
443 |
-
[0 ,15],
|
444 |
-
[0 ,16],
|
445 |
-
[15,17],
|
446 |
-
[16,18],
|
447 |
-
[21,19],
|
448 |
-
[19,20],
|
449 |
-
[14,21],
|
450 |
-
[11,24],
|
451 |
-
[24,22],
|
452 |
-
[22,23],
|
453 |
-
[0 ,38],
|
454 |
-
]
|
455 |
-
)
|
456 |
-
|
457 |
-
|
458 |
-
def get_openpose_skeleton():
|
459 |
-
return np.array(
|
460 |
-
[
|
461 |
-
[0 , 1],
|
462 |
-
[1 , 2],
|
463 |
-
[2 , 3],
|
464 |
-
[3 , 4],
|
465 |
-
[1 , 5],
|
466 |
-
[5 , 6],
|
467 |
-
[6 , 7],
|
468 |
-
[1 , 8],
|
469 |
-
[8 , 9],
|
470 |
-
[9 ,10],
|
471 |
-
[10,11],
|
472 |
-
[8 ,12],
|
473 |
-
[12,13],
|
474 |
-
[13,14],
|
475 |
-
[0 ,15],
|
476 |
-
[0 ,16],
|
477 |
-
[15,17],
|
478 |
-
[16,18],
|
479 |
-
[21,19],
|
480 |
-
[19,20],
|
481 |
-
[14,21],
|
482 |
-
[11,24],
|
483 |
-
[24,22],
|
484 |
-
[22,23],
|
485 |
-
]
|
486 |
-
)
|
487 |
-
|
488 |
-
|
489 |
-
def get_posetrack_joint_names():
|
490 |
-
return [
|
491 |
-
"nose",
|
492 |
-
"neck",
|
493 |
-
"headtop",
|
494 |
-
"lear",
|
495 |
-
"rear",
|
496 |
-
"lshoulder",
|
497 |
-
"rshoulder",
|
498 |
-
"lelbow",
|
499 |
-
"relbow",
|
500 |
-
"lwrist",
|
501 |
-
"rwrist",
|
502 |
-
"lhip",
|
503 |
-
"rhip",
|
504 |
-
"lknee",
|
505 |
-
"rknee",
|
506 |
-
"lankle",
|
507 |
-
"rankle"
|
508 |
-
]
|
509 |
-
|
510 |
-
|
511 |
-
def get_posetrack_original_kp_names():
|
512 |
-
return [
|
513 |
-
'nose',
|
514 |
-
'head_bottom',
|
515 |
-
'head_top',
|
516 |
-
'left_ear',
|
517 |
-
'right_ear',
|
518 |
-
'left_shoulder',
|
519 |
-
'right_shoulder',
|
520 |
-
'left_elbow',
|
521 |
-
'right_elbow',
|
522 |
-
'left_wrist',
|
523 |
-
'right_wrist',
|
524 |
-
'left_hip',
|
525 |
-
'right_hip',
|
526 |
-
'left_knee',
|
527 |
-
'right_knee',
|
528 |
-
'left_ankle',
|
529 |
-
'right_ankle'
|
530 |
-
]
|
531 |
-
|
532 |
-
|
533 |
-
def get_pennaction_joint_names():
|
534 |
-
return [
|
535 |
-
"headtop", # 0
|
536 |
-
"lshoulder", # 1
|
537 |
-
"rshoulder", # 2
|
538 |
-
"lelbow", # 3
|
539 |
-
"relbow", # 4
|
540 |
-
"lwrist", # 5
|
541 |
-
"rwrist", # 6
|
542 |
-
"lhip" , # 7
|
543 |
-
"rhip" , # 8
|
544 |
-
"lknee", # 9
|
545 |
-
"rknee" , # 10
|
546 |
-
"lankle", # 11
|
547 |
-
"rankle" # 12
|
548 |
-
]
|
549 |
-
|
550 |
-
|
551 |
-
def get_common_joint_names():
|
552 |
-
return [
|
553 |
-
"rankle", # 0 "lankle", # 0
|
554 |
-
"rknee", # 1 "lknee", # 1
|
555 |
-
"rhip", # 2 "lhip", # 2
|
556 |
-
"lhip", # 3 "rhip", # 3
|
557 |
-
"lknee", # 4 "rknee", # 4
|
558 |
-
"lankle", # 5 "rankle", # 5
|
559 |
-
"rwrist", # 6 "lwrist", # 6
|
560 |
-
"relbow", # 7 "lelbow", # 7
|
561 |
-
"rshoulder", # 8 "lshoulder", # 8
|
562 |
-
"lshoulder", # 9 "rshoulder", # 9
|
563 |
-
"lelbow", # 10 "relbow", # 10
|
564 |
-
"lwrist", # 11 "rwrist", # 11
|
565 |
-
"neck", # 12 "neck", # 12
|
566 |
-
"headtop", # 13 "headtop", # 13
|
567 |
-
]
|
568 |
-
|
569 |
-
|
570 |
-
def get_common_paper_joint_names():
|
571 |
-
return [
|
572 |
-
"Right Ankle", # 0 "lankle", # 0
|
573 |
-
"Right Knee", # 1 "lknee", # 1
|
574 |
-
"Right Hip", # 2 "lhip", # 2
|
575 |
-
"Left Hip", # 3 "rhip", # 3
|
576 |
-
"Left Knee", # 4 "rknee", # 4
|
577 |
-
"Left Ankle", # 5 "rankle", # 5
|
578 |
-
"Right Wrist", # 6 "lwrist", # 6
|
579 |
-
"Right Elbow", # 7 "lelbow", # 7
|
580 |
-
"Right Shoulder", # 8 "lshoulder", # 8
|
581 |
-
"Left Shoulder", # 9 "rshoulder", # 9
|
582 |
-
"Left Elbow", # 10 "relbow", # 10
|
583 |
-
"Left Wrist", # 11 "rwrist", # 11
|
584 |
-
"Neck", # 12 "neck", # 12
|
585 |
-
"Head", # 13 "headtop", # 13
|
586 |
-
]
|
587 |
-
|
588 |
-
|
589 |
-
def get_common_skeleton():
|
590 |
-
return np.array(
|
591 |
-
[
|
592 |
-
[ 0, 1 ],
|
593 |
-
[ 1, 2 ],
|
594 |
-
[ 3, 4 ],
|
595 |
-
[ 4, 5 ],
|
596 |
-
[ 6, 7 ],
|
597 |
-
[ 7, 8 ],
|
598 |
-
[ 8, 2 ],
|
599 |
-
[ 8, 9 ],
|
600 |
-
[ 9, 3 ],
|
601 |
-
[ 2, 3 ],
|
602 |
-
[ 8, 12],
|
603 |
-
[ 9, 10],
|
604 |
-
[12, 9 ],
|
605 |
-
[10, 11],
|
606 |
-
[12, 13],
|
607 |
-
]
|
608 |
-
)
|
609 |
-
|
610 |
-
|
611 |
-
def get_coco_joint_names():
|
612 |
-
return [
|
613 |
-
"nose", # 0
|
614 |
-
"leye", # 1
|
615 |
-
"reye", # 2
|
616 |
-
"lear", # 3
|
617 |
-
"rear", # 4
|
618 |
-
"lshoulder", # 5
|
619 |
-
"rshoulder", # 6
|
620 |
-
"lelbow", # 7
|
621 |
-
"relbow", # 8
|
622 |
-
"lwrist", # 9
|
623 |
-
"rwrist", # 10
|
624 |
-
"lhip", # 11
|
625 |
-
"rhip", # 12
|
626 |
-
"lknee", # 13
|
627 |
-
"rknee", # 14
|
628 |
-
"lankle", # 15
|
629 |
-
"rankle", # 16
|
630 |
-
]
|
631 |
-
|
632 |
-
|
633 |
-
def get_ochuman_joint_names():
|
634 |
-
return [
|
635 |
-
'rshoulder',
|
636 |
-
'relbow',
|
637 |
-
'rwrist',
|
638 |
-
'lshoulder',
|
639 |
-
'lelbow',
|
640 |
-
'lwrist',
|
641 |
-
'rhip',
|
642 |
-
'rknee',
|
643 |
-
'rankle',
|
644 |
-
'lhip',
|
645 |
-
'lknee',
|
646 |
-
'lankle',
|
647 |
-
'headtop',
|
648 |
-
'neck',
|
649 |
-
'rear',
|
650 |
-
'lear',
|
651 |
-
'nose',
|
652 |
-
'reye',
|
653 |
-
'leye'
|
654 |
-
]
|
655 |
-
|
656 |
-
|
657 |
-
def get_crowdpose_joint_names():
|
658 |
-
return [
|
659 |
-
'lshoulder',
|
660 |
-
'rshoulder',
|
661 |
-
'lelbow',
|
662 |
-
'relbow',
|
663 |
-
'lwrist',
|
664 |
-
'rwrist',
|
665 |
-
'lhip',
|
666 |
-
'rhip',
|
667 |
-
'lknee',
|
668 |
-
'rknee',
|
669 |
-
'lankle',
|
670 |
-
'rankle',
|
671 |
-
'headtop',
|
672 |
-
'neck'
|
673 |
-
]
|
674 |
-
|
675 |
-
def get_coco_skeleton():
|
676 |
-
# 0 - nose,
|
677 |
-
# 1 - leye,
|
678 |
-
# 2 - reye,
|
679 |
-
# 3 - lear,
|
680 |
-
# 4 - rear,
|
681 |
-
# 5 - lshoulder,
|
682 |
-
# 6 - rshoulder,
|
683 |
-
# 7 - lelbow,
|
684 |
-
# 8 - relbow,
|
685 |
-
# 9 - lwrist,
|
686 |
-
# 10 - rwrist,
|
687 |
-
# 11 - lhip,
|
688 |
-
# 12 - rhip,
|
689 |
-
# 13 - lknee,
|
690 |
-
# 14 - rknee,
|
691 |
-
# 15 - lankle,
|
692 |
-
# 16 - rankle,
|
693 |
-
return np.array(
|
694 |
-
[
|
695 |
-
[15, 13],
|
696 |
-
[13, 11],
|
697 |
-
[16, 14],
|
698 |
-
[14, 12],
|
699 |
-
[11, 12],
|
700 |
-
[ 5, 11],
|
701 |
-
[ 6, 12],
|
702 |
-
[ 5, 6 ],
|
703 |
-
[ 5, 7 ],
|
704 |
-
[ 6, 8 ],
|
705 |
-
[ 7, 9 ],
|
706 |
-
[ 8, 10],
|
707 |
-
[ 1, 2 ],
|
708 |
-
[ 0, 1 ],
|
709 |
-
[ 0, 2 ],
|
710 |
-
[ 1, 3 ],
|
711 |
-
[ 2, 4 ],
|
712 |
-
[ 3, 5 ],
|
713 |
-
[ 4, 6 ]
|
714 |
-
]
|
715 |
-
)
|
716 |
-
|
717 |
-
|
718 |
-
def get_mpii_joint_names():
|
719 |
-
return [
|
720 |
-
"rankle", # 0
|
721 |
-
"rknee", # 1
|
722 |
-
"rhip", # 2
|
723 |
-
"lhip", # 3
|
724 |
-
"lknee", # 4
|
725 |
-
"lankle", # 5
|
726 |
-
"hip", # 6
|
727 |
-
"thorax", # 7
|
728 |
-
"neck", # 8
|
729 |
-
"headtop", # 9
|
730 |
-
"rwrist", # 10
|
731 |
-
"relbow", # 11
|
732 |
-
"rshoulder", # 12
|
733 |
-
"lshoulder", # 13
|
734 |
-
"lelbow", # 14
|
735 |
-
"lwrist", # 15
|
736 |
-
]
|
737 |
-
|
738 |
-
|
739 |
-
def get_mpii_skeleton():
|
740 |
-
# 0 - rankle,
|
741 |
-
# 1 - rknee,
|
742 |
-
# 2 - rhip,
|
743 |
-
# 3 - lhip,
|
744 |
-
# 4 - lknee,
|
745 |
-
# 5 - lankle,
|
746 |
-
# 6 - hip,
|
747 |
-
# 7 - thorax,
|
748 |
-
# 8 - neck,
|
749 |
-
# 9 - headtop,
|
750 |
-
# 10 - rwrist,
|
751 |
-
# 11 - relbow,
|
752 |
-
# 12 - rshoulder,
|
753 |
-
# 13 - lshoulder,
|
754 |
-
# 14 - lelbow,
|
755 |
-
# 15 - lwrist,
|
756 |
-
return np.array(
|
757 |
-
[
|
758 |
-
[ 0, 1 ],
|
759 |
-
[ 1, 2 ],
|
760 |
-
[ 2, 6 ],
|
761 |
-
[ 6, 3 ],
|
762 |
-
[ 3, 4 ],
|
763 |
-
[ 4, 5 ],
|
764 |
-
[ 6, 7 ],
|
765 |
-
[ 7, 8 ],
|
766 |
-
[ 8, 9 ],
|
767 |
-
[ 7, 12],
|
768 |
-
[12, 11],
|
769 |
-
[11, 10],
|
770 |
-
[ 7, 13],
|
771 |
-
[13, 14],
|
772 |
-
[14, 15]
|
773 |
-
]
|
774 |
-
)
|
775 |
-
|
776 |
-
|
777 |
-
def get_aich_joint_names():
|
778 |
-
return [
|
779 |
-
"rshoulder", # 0
|
780 |
-
"relbow", # 1
|
781 |
-
"rwrist", # 2
|
782 |
-
"lshoulder", # 3
|
783 |
-
"lelbow", # 4
|
784 |
-
"lwrist", # 5
|
785 |
-
"rhip", # 6
|
786 |
-
"rknee", # 7
|
787 |
-
"rankle", # 8
|
788 |
-
"lhip", # 9
|
789 |
-
"lknee", # 10
|
790 |
-
"lankle", # 11
|
791 |
-
"headtop", # 12
|
792 |
-
"neck", # 13
|
793 |
-
]
|
794 |
-
|
795 |
-
|
796 |
-
def get_aich_skeleton():
|
797 |
-
# 0 - rshoulder,
|
798 |
-
# 1 - relbow,
|
799 |
-
# 2 - rwrist,
|
800 |
-
# 3 - lshoulder,
|
801 |
-
# 4 - lelbow,
|
802 |
-
# 5 - lwrist,
|
803 |
-
# 6 - rhip,
|
804 |
-
# 7 - rknee,
|
805 |
-
# 8 - rankle,
|
806 |
-
# 9 - lhip,
|
807 |
-
# 10 - lknee,
|
808 |
-
# 11 - lankle,
|
809 |
-
# 12 - headtop,
|
810 |
-
# 13 - neck,
|
811 |
-
return np.array(
|
812 |
-
[
|
813 |
-
[ 0, 1 ],
|
814 |
-
[ 1, 2 ],
|
815 |
-
[ 3, 4 ],
|
816 |
-
[ 4, 5 ],
|
817 |
-
[ 6, 7 ],
|
818 |
-
[ 7, 8 ],
|
819 |
-
[ 9, 10],
|
820 |
-
[10, 11],
|
821 |
-
[12, 13],
|
822 |
-
[13, 0 ],
|
823 |
-
[13, 3 ],
|
824 |
-
[ 0, 6 ],
|
825 |
-
[ 3, 9 ]
|
826 |
-
]
|
827 |
-
)
|
828 |
-
|
829 |
-
|
830 |
-
def get_3dpw_joint_names():
|
831 |
-
return [
|
832 |
-
"nose", # 0
|
833 |
-
"thorax", # 1
|
834 |
-
"rshoulder", # 2
|
835 |
-
"relbow", # 3
|
836 |
-
"rwrist", # 4
|
837 |
-
"lshoulder", # 5
|
838 |
-
"lelbow", # 6
|
839 |
-
"lwrist", # 7
|
840 |
-
"rhip", # 8
|
841 |
-
"rknee", # 9
|
842 |
-
"rankle", # 10
|
843 |
-
"lhip", # 11
|
844 |
-
"lknee", # 12
|
845 |
-
"lankle", # 13
|
846 |
-
]
|
847 |
-
|
848 |
-
|
849 |
-
def get_3dpw_skeleton():
|
850 |
-
return np.array(
|
851 |
-
[
|
852 |
-
[ 0, 1 ],
|
853 |
-
[ 1, 2 ],
|
854 |
-
[ 2, 3 ],
|
855 |
-
[ 3, 4 ],
|
856 |
-
[ 1, 5 ],
|
857 |
-
[ 5, 6 ],
|
858 |
-
[ 6, 7 ],
|
859 |
-
[ 2, 8 ],
|
860 |
-
[ 5, 11],
|
861 |
-
[ 8, 11],
|
862 |
-
[ 8, 9 ],
|
863 |
-
[ 9, 10],
|
864 |
-
[11, 12],
|
865 |
-
[12, 13]
|
866 |
-
]
|
867 |
-
)
|
868 |
-
|
869 |
-
|
870 |
-
def get_smplcoco_joint_names():
|
871 |
-
return [
|
872 |
-
"rankle", # 0
|
873 |
-
"rknee", # 1
|
874 |
-
"rhip", # 2
|
875 |
-
"lhip", # 3
|
876 |
-
"lknee", # 4
|
877 |
-
"lankle", # 5
|
878 |
-
"rwrist", # 6
|
879 |
-
"relbow", # 7
|
880 |
-
"rshoulder", # 8
|
881 |
-
"lshoulder", # 9
|
882 |
-
"lelbow", # 10
|
883 |
-
"lwrist", # 11
|
884 |
-
"neck", # 12
|
885 |
-
"headtop", # 13
|
886 |
-
"nose", # 14
|
887 |
-
"leye", # 15
|
888 |
-
"reye", # 16
|
889 |
-
"lear", # 17
|
890 |
-
"rear", # 18
|
891 |
-
]
|
892 |
-
|
893 |
-
|
894 |
-
def get_smplcoco_skeleton():
|
895 |
-
return np.array(
|
896 |
-
[
|
897 |
-
[ 0, 1 ],
|
898 |
-
[ 1, 2 ],
|
899 |
-
[ 3, 4 ],
|
900 |
-
[ 4, 5 ],
|
901 |
-
[ 6, 7 ],
|
902 |
-
[ 7, 8 ],
|
903 |
-
[ 8, 12],
|
904 |
-
[12, 9 ],
|
905 |
-
[ 9, 10],
|
906 |
-
[10, 11],
|
907 |
-
[12, 13],
|
908 |
-
[14, 15],
|
909 |
-
[15, 17],
|
910 |
-
[16, 18],
|
911 |
-
[14, 16],
|
912 |
-
[ 8, 2 ],
|
913 |
-
[ 9, 3 ],
|
914 |
-
[ 2, 3 ],
|
915 |
-
]
|
916 |
-
)
|
917 |
-
|
918 |
-
|
919 |
-
def get_smpl_joint_names():
|
920 |
-
return [
|
921 |
-
'hips', # 0
|
922 |
-
'leftUpLeg', # 1
|
923 |
-
'rightUpLeg', # 2
|
924 |
-
'spine', # 3
|
925 |
-
'leftLeg', # 4
|
926 |
-
'rightLeg', # 5
|
927 |
-
'spine1', # 6
|
928 |
-
'leftFoot', # 7
|
929 |
-
'rightFoot', # 8
|
930 |
-
'spine2', # 9
|
931 |
-
'leftToeBase', # 10
|
932 |
-
'rightToeBase', # 11
|
933 |
-
'neck', # 12
|
934 |
-
'leftShoulder', # 13
|
935 |
-
'rightShoulder', # 14
|
936 |
-
'head', # 15
|
937 |
-
'leftArm', # 16
|
938 |
-
'rightArm', # 17
|
939 |
-
'leftForeArm', # 18
|
940 |
-
'rightForeArm', # 19
|
941 |
-
'leftHand', # 20
|
942 |
-
'rightHand', # 21
|
943 |
-
'leftHandIndex1', # 22
|
944 |
-
'rightHandIndex1', # 23
|
945 |
-
]
|
946 |
-
|
947 |
-
|
948 |
-
def get_smpl_paper_joint_names():
|
949 |
-
return [
|
950 |
-
'Hips', # 0
|
951 |
-
'Left Hip', # 1
|
952 |
-
'Right Hip', # 2
|
953 |
-
'Spine', # 3
|
954 |
-
'Left Knee', # 4
|
955 |
-
'Right Knee', # 5
|
956 |
-
'Spine_1', # 6
|
957 |
-
'Left Ankle', # 7
|
958 |
-
'Right Ankle', # 8
|
959 |
-
'Spine_2', # 9
|
960 |
-
'Left Toe', # 10
|
961 |
-
'Right Toe', # 11
|
962 |
-
'Neck', # 12
|
963 |
-
'Left Shoulder', # 13
|
964 |
-
'Right Shoulder', # 14
|
965 |
-
'Head', # 15
|
966 |
-
'Left Arm', # 16
|
967 |
-
'Right Arm', # 17
|
968 |
-
'Left Elbow', # 18
|
969 |
-
'Right Elbow', # 19
|
970 |
-
'Left Hand', # 20
|
971 |
-
'Right Hand', # 21
|
972 |
-
'Left Thumb', # 22
|
973 |
-
'Right Thumb', # 23
|
974 |
-
]
|
975 |
-
|
976 |
-
|
977 |
-
def get_smpl_neighbor_triplets():
|
978 |
-
return [
|
979 |
-
[ 0, 1, 2 ], # 0
|
980 |
-
[ 1, 4, 0 ], # 1
|
981 |
-
[ 2, 0, 5 ], # 2
|
982 |
-
[ 3, 0, 6 ], # 3
|
983 |
-
[ 4, 7, 1 ], # 4
|
984 |
-
[ 5, 2, 8 ], # 5
|
985 |
-
[ 6, 3, 9 ], # 6
|
986 |
-
[ 7, 10, 4 ], # 7
|
987 |
-
[ 8, 5, 11], # 8
|
988 |
-
[ 9, 13, 14], # 9
|
989 |
-
[10, 7, 4 ], # 10
|
990 |
-
[11, 8, 5 ], # 11
|
991 |
-
[12, 9, 15], # 12
|
992 |
-
[13, 16, 9 ], # 13
|
993 |
-
[14, 9, 17], # 14
|
994 |
-
[15, 9, 12], # 15
|
995 |
-
[16, 18, 13], # 16
|
996 |
-
[17, 14, 19], # 17
|
997 |
-
[18, 20, 16], # 18
|
998 |
-
[19, 17, 21], # 19
|
999 |
-
[20, 22, 18], # 20
|
1000 |
-
[21, 19, 23], # 21
|
1001 |
-
[22, 20, 18], # 22
|
1002 |
-
[23, 19, 21], # 23
|
1003 |
-
]
|
1004 |
-
|
1005 |
-
|
1006 |
-
def get_smpl_skeleton():
|
1007 |
-
return np.array(
|
1008 |
-
[
|
1009 |
-
[ 0, 1 ],
|
1010 |
-
[ 0, 2 ],
|
1011 |
-
[ 0, 3 ],
|
1012 |
-
[ 1, 4 ],
|
1013 |
-
[ 2, 5 ],
|
1014 |
-
[ 3, 6 ],
|
1015 |
-
[ 4, 7 ],
|
1016 |
-
[ 5, 8 ],
|
1017 |
-
[ 6, 9 ],
|
1018 |
-
[ 7, 10],
|
1019 |
-
[ 8, 11],
|
1020 |
-
[ 9, 12],
|
1021 |
-
[ 9, 13],
|
1022 |
-
[ 9, 14],
|
1023 |
-
[12, 15],
|
1024 |
-
[13, 16],
|
1025 |
-
[14, 17],
|
1026 |
-
[16, 18],
|
1027 |
-
[17, 19],
|
1028 |
-
[18, 20],
|
1029 |
-
[19, 21],
|
1030 |
-
[20, 22],
|
1031 |
-
[21, 23],
|
1032 |
-
]
|
1033 |
-
)
|
1034 |
-
|
1035 |
-
|
1036 |
-
def map_spin_joints_to_smpl():
|
1037 |
-
# this function primarily will be used to copy 2D keypoint
|
1038 |
-
# confidences to pose parameters
|
1039 |
-
return [
|
1040 |
-
[(39, 27, 28), 0], # hip,lhip,rhip->hips
|
1041 |
-
[(28,), 1], # lhip->leftUpLeg
|
1042 |
-
[(27,), 2], # rhip->rightUpLeg
|
1043 |
-
[(41, 27, 28, 39), 3], # Spine->spine
|
1044 |
-
[(29,), 4], # lknee->leftLeg
|
1045 |
-
[(26,), 5], # rknee->rightLeg
|
1046 |
-
[(41, 40, 33, 34,), 6], # spine, thorax ->spine1
|
1047 |
-
[(30,), 7], # lankle->leftFoot
|
1048 |
-
[(25,), 8], # rankle->rightFoot
|
1049 |
-
[(40, 33, 34), 9], # thorax,shoulders->spine2
|
1050 |
-
[(30,), 10], # lankle -> leftToe
|
1051 |
-
[(25,), 11], # rankle -> rightToe
|
1052 |
-
[(37, 42, 33, 34), 12], # neck, shoulders -> neck
|
1053 |
-
[(34,), 13], # lshoulder->leftShoulder
|
1054 |
-
[(33,), 14], # rshoulder->rightShoulder
|
1055 |
-
[(33, 34, 38, 43, 44, 45, 46, 47, 48,), 15], # nose, eyes, ears, headtop, shoulders->head
|
1056 |
-
[(34,), 16], # lshoulder->leftArm
|
1057 |
-
[(33,), 17], # rshoulder->rightArm
|
1058 |
-
[(35,), 18], # lelbow->leftForeArm
|
1059 |
-
[(32,), 19], # relbow->rightForeArm
|
1060 |
-
[(36,), 20], # lwrist->leftHand
|
1061 |
-
[(31,), 21], # rwrist->rightHand
|
1062 |
-
[(36,), 22], # lhand -> leftHandIndex
|
1063 |
-
[(31,), 23], # rhand -> rightHandIndex
|
1064 |
-
]
|
1065 |
-
|
1066 |
-
|
1067 |
-
def map_smpl_to_common():
|
1068 |
-
return [
|
1069 |
-
[(11, 8), 0], # rightToe, rightFoot -> rankle
|
1070 |
-
[(5,), 1], # rightleg -> rknee,
|
1071 |
-
[(2,), 2], # rhip
|
1072 |
-
[(1,), 3], # lhip
|
1073 |
-
[(4,), 4], # leftLeg -> lknee
|
1074 |
-
[(10, 7), 5], # lefttoe, leftfoot -> lankle
|
1075 |
-
[(21, 23), 6], # rwrist
|
1076 |
-
[(18,), 7], # relbow
|
1077 |
-
[(17, 14), 8], # rshoulder
|
1078 |
-
[(16, 13), 9], # lshoulder
|
1079 |
-
[(19,), 10], # lelbow
|
1080 |
-
[(20, 22), 11], # lwrist
|
1081 |
-
[(0, 3, 6, 9, 12), 12], # neck
|
1082 |
-
[(15,), 13], # headtop
|
1083 |
-
]
|
1084 |
-
|
1085 |
-
|
1086 |
-
def relation_among_spin_joints():
|
1087 |
-
# this function primarily will be used to copy 2D keypoint
|
1088 |
-
# confidences to 3D joints
|
1089 |
-
return [
|
1090 |
-
[(), 25],
|
1091 |
-
[(), 26],
|
1092 |
-
[(39,), 27],
|
1093 |
-
[(39,), 28],
|
1094 |
-
[(), 29],
|
1095 |
-
[(), 30],
|
1096 |
-
[(), 31],
|
1097 |
-
[(), 32],
|
1098 |
-
[(), 33],
|
1099 |
-
[(), 34],
|
1100 |
-
[(), 35],
|
1101 |
-
[(), 36],
|
1102 |
-
[(40,42,44,43,38,33,34,), 37],
|
1103 |
-
[(43,44,45,46,47,48,33,34,), 38],
|
1104 |
-
[(27,28,), 39],
|
1105 |
-
[(27,28,37,41,42,), 40],
|
1106 |
-
[(27,28,39,40,), 41],
|
1107 |
-
[(37,38,44,45,46,47,48,), 42],
|
1108 |
-
[(44,45,46,47,48,38,42,37,33,34,), 43],
|
1109 |
-
[(44,45,46,47,48,38,42,37,33,34), 44],
|
1110 |
-
[(44,45,46,47,48,38,42,37,33,34), 45],
|
1111 |
-
[(44,45,46,47,48,38,42,37,33,34), 46],
|
1112 |
-
[(44,45,46,47,48,38,42,37,33,34), 47],
|
1113 |
-
[(44,45,46,47,48,38,42,37,33,34), 48],
|
1114 |
-
]
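
Note: these name lists exist so keypoints can be reordered between dataset conventions by matching strings instead of hard-coding indices, just as the index-permutation helper at the top of this file does. A minimal usage sketch (not part of the deleted file; permute_joints is a hypothetical helper):

import numpy as np

def permute_joints(coco_joints2d):
    # Reorder COCO keypoints (17, 2) into the common joint order above;
    # joints COCO lacks (neck, headtop) are simply skipped.
    src_names = get_coco_joint_names()
    dst_names = get_common_joint_names()
    idxs = [src_names.index(h) for h in dst_names if h in src_names]
    return coco_joints2d[idxs]

joints = np.random.rand(17, 2)   # stand-in for real COCO detections
common = permute_joints(joints)  # (12, 2): neck and headtop are missing in COCO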
utils/loss.py
DELETED
@@ -1,207 +0,0 @@
import torch
import torch.nn as nn
from common import constants
from models.smpl import SMPL
from smplx import SMPLX
import pickle as pkl
import numpy as np
from utils.mesh_utils import save_results_mesh
from utils.diff_renderer import Pytorch3D
import os
import cv2


class sem_loss_function(nn.Module):
    def __init__(self):
        super(sem_loss_function, self).__init__()
        self.ce = nn.BCELoss()

    def forward(self, y_true, y_pred):
        loss = self.ce(y_pred, y_true)
        return loss


class class_loss_function(nn.Module):
    def __init__(self):
        super(class_loss_function, self).__init__()
        self.ce_loss = nn.BCELoss()
        # self.ce_loss = nn.MultiLabelSoftMarginLoss()
        # self.ce_loss = nn.MultiLabelMarginLoss()

    def forward(self, y_true, y_pred, valid_mask):
        # y_true = torch.squeeze(y_true, 1).long()
        # y_true = torch.squeeze(y_true, 1)
        # y_pred = torch.squeeze(y_pred, 1)
        bs = y_true.shape[0]
        if bs != 1:
            y_pred = y_pred[valid_mask == 1]
            y_true = y_true[valid_mask == 1]
        if len(y_pred) > 0:
            return self.ce_loss(y_pred, y_true)
        else:
            return torch.tensor(0.0).to(y_pred.device)


class pixel_anchoring_function(nn.Module):
    def __init__(self, model_type, device='cuda'):
        super(pixel_anchoring_function, self).__init__()

        self.device = device

        self.model_type = model_type

        if self.model_type == 'smplx':
            # load mapping from smpl vertices to smplx vertices
            mapping_pkl = os.path.join(constants.CONTACT_MAPPING_PATH, "smpl_to_smplx.pkl")
            with open(mapping_pkl, 'rb') as f:
                smpl_to_smplx_mapping = pkl.load(f)
            smpl_to_smplx_mapping = smpl_to_smplx_mapping["matrix"]
            self.smpl_to_smplx_mapping = torch.from_numpy(smpl_to_smplx_mapping).float().to(self.device)

        # Setup the SMPL model
        if self.model_type == 'smpl':
            self.n_vertices = 6890
            self.body_model = SMPL(constants.SMPL_MODEL_DIR).to(self.device)
        if self.model_type == 'smplx':
            self.n_vertices = 10475
            self.body_model = SMPLX(constants.SMPLX_MODEL_DIR,
                                    num_betas=10,
                                    use_pca=False).to(self.device)
        self.body_faces = torch.LongTensor(self.body_model.faces.astype(np.int32)).to(self.device)

        self.ce_loss = nn.BCELoss()

    def get_posed_mesh(self, body_params, debug=False):
        betas = body_params['betas']
        pose = body_params['pose']
        transl = body_params['transl']

        # extra smplx params
        extra_args = {'jaw_pose': torch.zeros((betas.shape[0], 3)).float().to(self.device),
                      'leye_pose': torch.zeros((betas.shape[0], 3)).float().to(self.device),
                      'reye_pose': torch.zeros((betas.shape[0], 3)).float().to(self.device),
                      'expression': torch.zeros((betas.shape[0], 10)).float().to(self.device),
                      'left_hand_pose': torch.zeros((betas.shape[0], 45)).float().to(self.device),
                      'right_hand_pose': torch.zeros((betas.shape[0], 45)).float().to(self.device)}

        smpl_output = self.body_model(betas=betas,
                                      body_pose=pose[:, 3:],
                                      global_orient=pose[:, :3],
                                      pose2rot=True,
                                      transl=transl,
                                      **extra_args)
        smpl_verts = smpl_output.vertices
        smpl_joints = smpl_output.joints

        if debug:
            for mesh_i in range(smpl_verts.shape[0]):
                out_dir = 'temp_meshes'
                os.makedirs(out_dir, exist_ok=True)
                out_file = os.path.join(out_dir, f'temp_mesh_{mesh_i:04d}.obj')
                save_results_mesh(smpl_verts[mesh_i], self.body_model.faces, out_file)
        return smpl_verts, smpl_joints

    def render_batch(self, smpl_verts, cam_k, img_scale_factor, vertex_colors=None, face_textures=None, debug=False):

        bs = smpl_verts.shape[0]

        # Incorporate resizing factor into the camera
        img_w = 256  # TODO: Remove hardcoding
        img_h = 256  # TODO: Remove hardcoding
        focal_length_x = cam_k[:, 0, 0] * img_scale_factor[:, 0]
        focal_length_y = cam_k[:, 1, 1] * img_scale_factor[:, 1]
        # convert to float for pytorch3d
        focal_length_x, focal_length_y = focal_length_x.float(), focal_length_y.float()

        # concatenate focal length
        focal_length = torch.stack([focal_length_x, focal_length_y], dim=1)

        # Setup renderer
        renderer = Pytorch3D(img_h=img_h,
                             img_w=img_w,
                             focal_length=focal_length,
                             smpl_faces=self.body_faces,
                             texture_mode='deco',
                             vertex_colors=vertex_colors,
                             face_textures=face_textures,
                             is_train=True,
                             is_cam_batch=True)
        front_view = renderer(smpl_verts)
        if debug:
            # visualize the front view as images in a temp_images folder
            for i in range(bs):
                front_view_rgb = front_view[i, :3, :, :].permute(1, 2, 0).detach().cpu()
                front_view_mask = front_view[i, 3, :, :].detach().cpu()
                out_dir = 'temp_images'
                os.makedirs(out_dir, exist_ok=True)
                out_file_rgb = os.path.join(out_dir, f'{i:04d}_rgb.png')
                out_file_mask = os.path.join(out_dir, f'{i:04d}_mask.png')
                cv2.imwrite(out_file_rgb, front_view_rgb.numpy() * 255)
                cv2.imwrite(out_file_mask, front_view_mask.numpy() * 255)

        return front_view

    def paint_contact(self, pred_contact):
        """
        Paints the contact vertices on the SMPL mesh.

        Args:
            pred_contact: probabilities of contact vertices

        Returns:
            pred_rgb: RGB colors for the contact vertices
        """
        bs = pred_contact.shape[0]

        # initialize black and white colors
        colors = torch.tensor([[0, 0, 0], [1, 1, 1]]).float().to(self.device)
        colors = torch.unsqueeze(colors, 0).expand(bs, -1, -1)

        # add another dimension to the contact probabilities for inverse probabilities
        pred_contact = torch.unsqueeze(pred_contact, 2)
        pred_contact = torch.cat((1 - pred_contact, pred_contact), 2)

        # get pred_rgb colors
        pred_vert_rgb = torch.bmm(pred_contact, colors)
        pred_face_rgb = pred_vert_rgb[:, self.body_faces, :][:, :, 0, :]  # take the first vertex color
        pred_face_texture = torch.zeros((bs, self.body_faces.shape[0], 1, 1, 3), dtype=torch.float32).to(self.device)
        pred_face_texture[:, :, 0, 0, :] = pred_face_rgb
        return pred_vert_rgb, pred_face_texture

    def forward(self, pred_contact, body_params, cam_k, img_scale_factor, gt_contact_polygon, valid_mask):
        """
        Takes predicted contact labels (probabilities), transfers them to the posed mesh and
        renders them to the image. The loss is computed between the rendered contact and the
        ground-truth polygons from HOT.

        Args:
            pred_contact: predicted contact labels (probabilities)
            body_params: SMPL parameters in camera coords
            cam_k: camera intrinsics
            gt_contact_polygon: ground-truth polygons from HOT
        """
        # convert pred_contact to smplx
        bs = pred_contact.shape[0]
        if self.model_type == 'smplx':
            smpl_to_smplx_mapping = self.smpl_to_smplx_mapping[None].expand(bs, -1, -1)
            pred_contact = torch.bmm(smpl_to_smplx_mapping, pred_contact[..., None])
            pred_contact = pred_contact.squeeze()

        # get the posed mesh
        smpl_verts, smpl_joints = self.get_posed_mesh(body_params)

        # paint the contact vertices on the mesh
        vertex_colors, face_textures = self.paint_contact(pred_contact)

        # render the mesh
        front_view = self.render_batch(smpl_verts, cam_k, img_scale_factor, vertex_colors, face_textures)
        front_view_rgb = front_view[:, :3, :, :].permute(0, 2, 3, 1)
        front_view_mask = front_view[:, 3, :, :]

        # compute segmentation loss between rendered contact mask and ground truth contact mask
        front_view_rgb = front_view_rgb[valid_mask == 1]
        gt_contact_polygon = gt_contact_polygon[valid_mask == 1]
        loss = self.ce_loss(front_view_rgb, gt_contact_polygon)
        return loss, front_view_rgb, front_view_mask
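
The vertex-painting step in paint_contact is just a per-vertex convex combination of two colors weighted by contact probability. A self-contained sketch of that math (shapes are illustrative; only torch is assumed):

import torch

bs, n_verts = 2, 6890
pred_contact = torch.rand(bs, n_verts)               # contact probabilities in [0, 1]

colors = torch.tensor([[0., 0., 0.], [1., 1., 1.]])  # black = no contact, white = contact
colors = colors.unsqueeze(0).expand(bs, -1, -1)      # (bs, 2, 3)

probs = torch.stack((1 - pred_contact, pred_contact), dim=2)  # (bs, n_verts, 2)
vert_rgb = torch.bmm(probs, colors)                           # (bs, n_verts, 3), grayscale blend
assert vert_rgb.shape == (bs, n_verts, 3)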
utils/mesh_utils.py
DELETED
@@ -1,6 +0,0 @@
import trimesh

def save_results_mesh(vertices, faces, filename):
    mesh = trimesh.Trimesh(vertices, faces, process=False)
    mesh.export(filename)
    print(f'save results to {filename}')
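
A quick usage sketch for the helper above (a unit tetrahedron written to a hypothetical path):

import numpy as np

verts = np.array([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
faces = np.array([[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]])
save_results_mesh(verts, faces, 'debug_tetrahedron.obj')  # prints 'save results to ...'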
utils/metrics.py
DELETED
@@ -1,106 +0,0 @@
import numpy as np
import torch
import monai.metrics as metrics
from common.constants import DIST_MATRIX_PATH

DIST_MATRIX = np.load(DIST_MATRIX_PATH)

def metric(mask, pred, back=True):
    iou = metrics.compute_meaniou(pred, mask, back, False)
    iou = iou.mean()

    return iou

def precision_recall_f1score(gt, pred):
    """
    Compute precision, recall, and f1
    """

    # gt = gt.numpy()
    # pred = pred.numpy()

    precision = torch.zeros(gt.shape[0])
    recall = torch.zeros(gt.shape[0])
    f1 = torch.zeros(gt.shape[0])

    for b in range(gt.shape[0]):
        tp_num = gt[b, pred[b, :] >= 0.5].sum()
        precision_denominator = (pred[b, :] >= 0.5).sum()
        recall_denominator = (gt[b, :]).sum()

        precision_ = tp_num / precision_denominator
        recall_ = tp_num / recall_denominator
        if precision_denominator == 0:  # if no pred
            precision_ = 1.
            recall_ = 0.
            f1_ = 0.
        elif recall_denominator == 0:  # if no GT
            precision_ = 0.
            recall_ = 1.
            f1_ = 0.
        elif (precision_ + recall_) <= 1e-10:  # to avoid precision issues
            precision_ = 0.
            recall_ = 0.
            f1_ = 0.
        else:
            f1_ = 2 * precision_ * recall_ / (precision_ + recall_)

        precision[b] = precision_
        recall[b] = recall_
        f1[b] = f1_

    return precision, recall, f1

def acc_precision_recall_f1score(gt, pred):
    """
    Compute acc, precision, recall, and f1
    """

    # gt = gt.numpy()
    # pred = pred.numpy()

    acc = torch.zeros(gt.shape[0])
    precision = torch.zeros(gt.shape[0])
    recall = torch.zeros(gt.shape[0])
    f1 = torch.zeros(gt.shape[0])

    for b in range(gt.shape[0]):
        tp_num = gt[b, pred[b, :] >= 0.5].sum()
        precision_denominator = (pred[b, :] >= 0.5).sum()
        recall_denominator = (gt[b, :]).sum()
        tn_num = gt.shape[-1] - precision_denominator - recall_denominator + tp_num

        acc_ = (tp_num + tn_num) / gt.shape[-1]
        precision_ = tp_num / (precision_denominator + 1e-10)
        recall_ = tp_num / (recall_denominator + 1e-10)
        f1_ = 2 * precision_ * recall_ / (precision_ + recall_ + 1e-10)

        acc[b] = acc_
        precision[b] = precision_
        recall[b] = recall_

    return acc, precision, recall, f1

def det_error_metric(pred, gt):

    gt = gt.detach().cpu()
    pred = pred.detach().cpu()

    dist_matrix = torch.tensor(DIST_MATRIX)

    false_positive_dist = torch.zeros(gt.shape[0])
    false_negative_dist = torch.zeros(gt.shape[0])

    for b in range(gt.shape[0]):
        gt_columns = dist_matrix[:, gt[b, :] == 1] if any(gt[b, :] == 1) else dist_matrix
        error_matrix = gt_columns[pred[b, :] >= 0.5, :] if any(pred[b, :] >= 0.5) else gt_columns

        false_positive_dist_ = error_matrix.min(dim=1)[0].mean()
        false_negative_dist_ = error_matrix.min(dim=0)[0].mean()

        false_positive_dist[b] = false_positive_dist_
        false_negative_dist[b] = false_negative_dist_

    return false_positive_dist, false_negative_dist
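
A hand-worked example for precision_recall_f1score (values chosen for illustration, not from the project): with ground truth [1, 1, 0, 0] and predictions [0.9, 0.2, 0.7, 0.1], thresholding at 0.5 gives one true positive out of two predicted positives and out of two actual positives, so precision = recall = f1 = 0.5:

import torch

gt = torch.tensor([[1., 1., 0., 0.]])
pred = torch.tensor([[0.9, 0.2, 0.7, 0.1]])
p, r, f1 = precision_recall_f1score(gt, pred)
print(p.item(), r.item(), f1.item())  # 0.5 0.5 0.5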
utils/smpl_uv.py
DELETED
@@ -1,167 +0,0 @@
import torch
import trimesh
import numpy as np
import skimage.io as io
from PIL import Image
from smplx import SMPL
from matplotlib import cm as mpl_cm, colors as mpl_colors
from trimesh.visual.color import face_to_vertex_color, vertex_to_face_color, to_rgba

from common import constants
from .colorwheel import make_color_wheel_image


def get_smpl_uv():
    uv_obj = 'data/body_models/smpl_uv_20200910/smpl_uv.obj'

    uv_map = []
    with open(uv_obj) as f:
        for line in f.readlines():
            if line.startswith('vt'):
                coords = [float(x) for x in line.split(' ')[1:]]
                uv_map.append(coords)

    uv_map = np.array(uv_map)

    return uv_map


def show_uv_texture():
    # image = io.imread('data/body_models/smpl_uv_20200910/smpl_uv_20200910.png')
    image = make_color_wheel_image(1024, 1024)
    image = Image.fromarray(image)

    uv = np.load('data/body_models/smpl_uv_20200910/uv_table.npy')  # get_smpl_uv()
    material = trimesh.visual.texture.SimpleMaterial(image=image)
    tex_visuals = trimesh.visual.TextureVisuals(uv=uv, image=image, material=material)

    smpl = SMPL(constants.SMPL_MODEL_DIR)

    faces = smpl.faces
    verts = smpl().vertices[0].detach().numpy()

    # assert(len(uv) == len(verts))
    print(uv.shape)
    vc = tex_visuals.to_color().vertex_colors
    fc = trimesh.visual.color.vertex_to_face_color(vc, faces)
    face_colors = fc.copy()
    fc = fc.astype(float)
    vc = vc.astype(float)
    fc[:, :3] = fc[:, :3] / 255.
    vc[:, :3] = vc[:, :3] / 255.
    print(fc[:, :3].max(), fc[:, :3].min(), fc[:, :3].mean())
    print(vc[:, :3].max(), vc[:, :3].min(), vc[:, :3].mean())
    np.save('data/body_models/smpl/color_wheel_face_colors.npy', fc)
    np.save('data/body_models/smpl/color_wheel_vertex_colors.npy', vc)
    print(fc.shape)
    mesh = trimesh.Trimesh(verts, faces, validate=True, process=False, face_colors=face_colors)
    # mesh = trimesh.load('data/body_models/smpl_uv_20200910/smpl_uv.obj', process=False)
    # mesh.visual = tex_visuals

    # import ipdb; ipdb.set_trace()
    # print(vc.shape)
    mesh.show()


def show_colored_mesh():
    cm = mpl_cm.get_cmap('jet')
    norm_gt = mpl_colors.Normalize()

    smpl = SMPL(constants.SMPL_MODEL_DIR)

    faces = smpl.faces
    verts = smpl().vertices[0].detach().numpy()

    m = trimesh.Trimesh(verts, faces, process=False)

    mode = 1
    if mode == 0:
        # mano_segm_labels = m.triangles_center
        face_labels = m.triangles_center
        face_colors = (face_labels - face_labels.min()) / np.ptp(face_labels)
    elif mode == 1:
        # print(face_labels.shape)
        face_labels = m.triangles_center
        face_labels = np.argsort(np.linalg.norm(face_labels, axis=-1))
        face_colors = np.ones((13776, 4))
        face_colors[:, 3] = 1.0
        face_colors[:, :3] = cm(norm_gt(face_labels))[:, :3]
    elif mode == 2:
        # breakpoint()
        fc = np.load('data/body_models/smpl_uv_20200910/data/vertex_texture.npy')[0, :, 0, 0, 0, :]
        face_colors = np.ones((13776, 4))
        face_colors[:, :3] = fc
    mesh = trimesh.Trimesh(verts, faces, process=False, face_colors=face_colors)
    mesh.show()


def get_tenet_texture(mode='smplpix'):
    # mode = 'smplpix', 'decomr'

    smpl = SMPL(constants.SMPL_MODEL_DIR)

    faces = smpl.faces
    verts = smpl().vertices[0].detach().numpy()

    m = trimesh.Trimesh(verts, faces, process=False)
    if mode == 'smplpix':
        # mano_segm_labels = m.triangles_center
        face_labels = m.triangles_center
        face_colors = (face_labels - face_labels.min()) / np.ptp(face_labels)
        texture = np.zeros((1, faces.shape[0], 1, 1, 1, 3), dtype=np.float32)
        texture[0, :, 0, 0, 0, :] = face_colors[:, :3]
        texture = torch.from_numpy(texture).float()
    elif mode == 'decomr':
        texture = np.load('data/body_models/smpl_uv_20200910/data/vertex_texture.npy')
        texture = torch.from_numpy(texture).float()
    elif mode == 'colorwheel':
        face_colors = np.load('data/body_models/smpl/color_wheel_face_colors.npy')
        texture = np.zeros((1, faces.shape[0], 1, 1, 1, 3), dtype=np.float32)
        texture[0, :, 0, 0, 0, :] = face_colors[:, :3]
        texture = torch.from_numpy(texture).float()
    else:
        raise ValueError(f'{mode} is not defined!')

    return texture


def save_tenet_textures(mode='smplpix'):
    # mode = 'smplpix', 'decomr'

    smpl = SMPL(constants.SMPL_MODEL_DIR)

    faces = smpl.faces
    verts = smpl().vertices[0].detach().numpy()

    m = trimesh.Trimesh(verts, faces, process=False)

    if mode == 'smplpix':
        # mano_segm_labels = m.triangles_center
        face_labels = m.triangles_center
        face_colors = (face_labels - face_labels.min()) / np.ptp(face_labels)
        texture = np.zeros((1, faces.shape[0], 1, 1, 1, 3), dtype=np.float32)
        texture[0, :, 0, 0, 0, :] = face_colors[:, :3]
        texture = torch.from_numpy(texture).float()

        vert_colors = face_to_vertex_color(m, face_colors).astype(float) / 255.0

    elif mode == 'decomr':
        texture = np.load('data/body_models/smpl_uv_20200910/data/vertex_texture.npy')
        texture = torch.from_numpy(texture).float()
        face_colors = texture[0, :, 0, 0, 0, :]
        vert_colors = face_to_vertex_color(m, face_colors).astype(float) / 255.0

    elif mode == 'colorwheel':
        face_colors = np.load('data/body_models/smpl/color_wheel_face_colors.npy')
        texture = np.zeros((1, faces.shape[0], 1, 1, 1, 3), dtype=np.float32)
        texture[0, :, 0, 0, 0, :] = face_colors[:, :3]
        texture = torch.from_numpy(texture).float()
        face_colors[:, :3] *= 255
        vert_colors = face_to_vertex_color(m, face_colors).astype(float) / 255.0
    else:
        raise ValueError(f'{mode} is not defined!')

    print(vert_colors.shape, vert_colors.max())
    np.save(f'data/body_models/smpl/{mode}_vertex_colors.npy', vert_colors)
    return texture
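
The 'smplpix' branch above colors each face by its min-max-normalized triangle-center position; the same normalization in isolation (random values stand in for m.triangles_center):

import numpy as np

face_centers = np.random.randn(13776, 3)  # stand-in for the SMPL triangle centers
face_colors = (face_centers - face_centers.min()) / np.ptp(face_centers)  # scaled to [0, 1]
texture = np.zeros((1, 13776, 1, 1, 1, 3), dtype=np.float32)
texture[0, :, 0, 0, 0, :] = face_colors   # one flat RGB value per face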
vis/__pycache__/visualize.cpython-37.pyc
DELETED
Binary file (6.7 kB)
vis/visualize.py
DELETED
@@ -1,209 +0,0 @@
|
|
1 |
-
import cv2
|
2 |
-
import os
|
3 |
-
import trimesh
|
4 |
-
import PIL.Image as pil_img
|
5 |
-
import numpy as np
|
6 |
-
import pyrender
|
7 |
-
from common import constants
|
8 |
-
|
9 |
-
os.environ['PYOPENGL_PLATFORM'] = 'egl'
|
10 |
-
|
11 |
-
def render_image(scene, img_res, img=None, viewer=False):
|
12 |
-
'''
|
13 |
-
Render the given pyrender scene and return the image. Can also overlay the mesh on an image.
|
14 |
-
'''
|
15 |
-
if viewer:
|
16 |
-
pyrender.Viewer(scene, use_raymond_lighting=True)
|
17 |
-
return 0
|
18 |
-
else:
|
19 |
-
r = pyrender.OffscreenRenderer(viewport_width=img_res,
|
20 |
-
viewport_height=img_res,
|
21 |
-
point_size=1.0)
|
22 |
-
color, _ = r.render(scene, flags=pyrender.RenderFlags.RGBA)
|
23 |
-
color = color.astype(np.float32) / 255.0
|
24 |
-
|
25 |
-
if img is not None:
|
26 |
-
valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis]
|
27 |
-
input_img = img.detach().cpu().numpy()
|
28 |
-
output_img = (color[:, :, :-1] * valid_mask +
|
29 |
-
(1 - valid_mask) * input_img)
|
30 |
-
else:
|
31 |
-
output_img = color
|
32 |
-
return output_img
|
33 |
-
|
34 |
-
def create_scene(mesh, img, focal_length=500, camera_center=250, img_res=500):
|
35 |
-
# Setup the scene
|
36 |
-
scene = pyrender.Scene(bg_color=[1.0, 1.0, 1.0, 1.0],
|
37 |
-
ambient_light=(0.3, 0.3, 0.3))
|
38 |
-
# add mesh for camera
|
39 |
-
camera_pose = np.eye(4)
|
40 |
-
camera_rotation = np.eye(3, 3)
|
41 |
-
camera_translation = np.array([0., 0, 2.5])
|
42 |
-
camera_pose[:3, :3] = camera_rotation
|
43 |
-
camera_pose[:3, 3] = camera_rotation @ camera_translation
|
44 |
-
pyrencamera = pyrender.camera.IntrinsicsCamera(
|
45 |
-
fx=focal_length, fy=focal_length,
|
46 |
-
cx=camera_center, cy=camera_center)
|
47 |
-
scene.add(pyrencamera, pose=camera_pose)
|
48 |
-
# create and add light
|
49 |
-
light = pyrender.PointLight(color=[1.0, 1.0, 1.0], intensity=1)
|
50 |
-
light_pose = np.eye(4)
|
51 |
-
for lp in [[1, 1, 1], [-1, 1, 1], [1, -1, 1], [-1, -1, 1]]:
|
52 |
-
light_pose[:3, 3] = mesh.vertices.mean(0) + np.array(lp)
|
53 |
-
# out_mesh.vertices.mean(0) + np.array(lp)
|
54 |
-
scene.add(light, pose=light_pose)
|
55 |
-
# add body mesh
|
56 |
-
material = pyrender.MetallicRoughnessMaterial(
|
57 |
-
metallicFactor=0.0,
|
58 |
-
alphaMode='OPAQUE',
|
59 |
-
baseColorFactor=(1.0, 1.0, 0.9, 1.0))
|
60 |
-
mesh_images = []
|
61 |
-
|
62 |
-
# resize input image to fit the mesh image height
|
63 |
-
# print(img.shape)
|
64 |
-
img_height = img_res
|
65 |
-
img_width = int(img_height * img.shape[1] / img.shape[0])
|
66 |
-
img = cv2.resize(img, (img_width, img_height))
|
67 |
-
mesh_images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
68 |
-
|
69 |
-
for sideview_angle in [0, 90, 180, 270]:
|
70 |
-
out_mesh = mesh.copy()
|
71 |
-
rot = trimesh.transformations.rotation_matrix(
|
72 |
-
np.radians(sideview_angle), [0, 1, 0])
|
73 |
-
out_mesh.apply_transform(rot)
|
74 |
-
out_mesh = pyrender.Mesh.from_trimesh(
|
75 |
-
out_mesh,
|
76 |
-
material=material)
|
77 |
-
mesh_pose = np.eye(4)
|
78 |
-
scene.add(out_mesh, pose=mesh_pose, name='mesh')
|
79 |
-
output_img = render_image(scene, img_res)
|
80 |
-
output_img = pil_img.fromarray((output_img * 255).astype(np.uint8))
|
81 |
-
output_img = np.asarray(output_img)[:, :, :3]
|
82 |
-
mesh_images.append(output_img)
|
83 |
-
# delete the previous mesh
|
84 |
-
prev_mesh = scene.get_nodes(name='mesh').pop()
|
85 |
-
scene.remove_node(prev_mesh)
|
86 |
-
|
87 |
-
# show upside down view
|
88 |
-
for topview_angle in [90, 270]:
|
89 |
-
out_mesh = mesh.copy()
|
90 |
-
rot = trimesh.transformations.rotation_matrix(
|
91 |
-
np.radians(topview_angle), [1, 0, 0])
|
92 |
-
out_mesh.apply_transform(rot)
|
93 |
-
out_mesh = pyrender.Mesh.from_trimesh(
|
94 |
-
out_mesh,
|
95 |
-
material=material)
|
96 |
-
mesh_pose = np.eye(4)
|
97 |
-
scene.add(out_mesh, pose=mesh_pose, name='mesh')
|
98 |
-
output_img = render_image(scene, img_res)
|
99 |
-
output_img = pil_img.fromarray((output_img * 255).astype(np.uint8))
|
100 |
-
output_img = np.asarray(output_img)[:, :, :3]
|
101 |
-
mesh_images.append(output_img)
|
102 |
-
# delete the previous mesh
|
103 |
-
prev_mesh = scene.get_nodes(name='mesh').pop()
|
104 |
-
scene.remove_node(prev_mesh)
|
105 |
-
|
106 |
-
# stack images
|
107 |
-
IMG = np.hstack(mesh_images)
|
108 |
-
IMG = pil_img.fromarray(IMG)
|
109 |
-
IMG.thumbnail((3000, 3000))
|
110 |
-
return IMG
|
111 |
-
|
112 |
-
# img = cv2.imread('../samples/prox_N3OpenArea_03301_01_s001_frame_00694.jpg')
|
113 |
-
# mesh = trimesh.load('../samples/mesh.ply', process=False)
|
114 |
-
# comb_img = create_scene(mesh, img)
|
115 |
-
# comb_img.save('../samples/combined_image.png')
|
116 |
-
|
117 |
-
def unsplit(img, palette):
|
118 |
-
rgb_img = np.zeros((img.shape[0], img.shape[1], 3))
|
119 |
-
for i in range(img.shape[0]):
|
120 |
-
for j in range(img.shape[1]):
|
121 |
-
id = np.argmax(img[i, j, :])
|
122 |
-
rgb_img[i, j, :] = palette[id]
|
123 |
-
|
124 |
-
return rgb_img
|
125 |
-
|
126 |
-
def gen_render(output, normalize=True):
|
127 |
-
img = output['img'].cpu().numpy()
|
128 |
-
contact_labels_3d = output['contact_labels_3d_gt'].cpu().numpy()
|
129 |
-
contact_labels_3d_pred = output['contact_labels_3d_pred'].cpu().numpy()
|
130 |
-
sem_mask_gt = output['sem_mask_gt'].cpu().numpy()
|
131 |
-
sem_mask_pred = output['sem_mask_pred'].cpu().numpy()
part_mask_gt = output['part_mask_gt'].cpu().numpy()
part_mask_pred = output['part_mask_pred'].cpu().numpy()
contact_2d_gt_rgb = output['contact_2d_gt'].cpu().numpy()
contact_2d_pred_rgb = output['contact_2d_pred_rgb'].cpu().numpy()

# load the template mesh twice so GT and predicted contacts can be painted independently
mesh_path = './data/smpl/smpl_neutral_tpose.ply'
gt_mesh = trimesh.load(mesh_path, process=False)
pred_mesh = trimesh.load(mesh_path, process=False)

img = np.transpose(img[0], (1, 2, 0))
if normalize:
    # unnormalize the image before displaying
    mean = np.array(constants.IMG_NORM_MEAN, dtype=np.float32)
    std = np.array(constants.IMG_NORM_STD, dtype=np.float32)
    img = img * std + mean
img = img * 255
img = img.astype(np.uint8)
color = np.array([0, 0, 0, 255])  # opaque black for contacting vertices
th = 0.5  # contact-probability threshold

# paint every vertex whose contact probability exceeds the threshold
contact_labels_3d = contact_labels_3d[0, :]
for vid, val in enumerate(contact_labels_3d):
    if val >= th:
        gt_mesh.visual.vertex_colors[vid] = color

contact_labels_3d_pred = contact_labels_3d_pred[0, :]
for vid, val in enumerate(contact_labels_3d_pred):
    if val >= th:
        pred_mesh.visual.vertex_colors[vid] = color

gt_rend = create_scene(gt_mesh, img)
pred_rend = create_scene(pred_mesh, img)

sem_palette = [[220, 20, 60], [119, 11, 32], [0, 0, 142], [0, 0, 230], [106, 0, 228], [0, 60, 100], [0, 80, 100], [0, 0, 70], [0, 0, 192], [250, 170, 30], [100, 170, 30], [220, 220, 0], [175, 116, 175], [250, 0, 30], [165, 42, 42], [255, 77, 255], [0, 226, 252], [182, 182, 255], [0, 82, 0], [120, 166, 157], [110, 76, 0], [174, 57, 255], [199, 100, 0], [72, 0, 118], [255, 179, 240], [0, 125, 92], [209, 0, 151], [188, 208, 182], [0, 220, 176], [255, 99, 164], [92, 0, 73], [133, 129, 255], [78, 180, 255], [0, 228, 0], [174, 255, 243], [45, 89, 255], [134, 134, 103], [145, 148, 174], [255, 208, 186], [197, 226, 255], [171, 134, 1], [109, 63, 54], [207, 138, 255], [151, 0, 95], [9, 80, 61], [84, 105, 51], [74, 65, 105], [166, 196, 102], [208, 195, 210], [255, 109, 65], [0, 143, 149], [179, 0, 194], [209, 99, 106], [5, 121, 0], [227, 255, 205], [147, 186, 208], [153, 69, 1], [3, 95, 161], [163, 255, 0], [119, 0, 170], [0, 182, 199], [0, 165, 120], [183, 130, 88], [95, 32, 0], [130, 114, 135], [110, 129, 133], [166, 74, 118], [219, 142, 185], [79, 210, 114], [178, 90, 62], [65, 70, 15], [127, 167, 115], [59, 105, 106], [142, 108, 45], [196, 172, 0], [95, 54, 80], [128, 76, 255], [201, 57, 1], [246, 0, 122], [191, 162, 208], [255, 255, 128], [147, 211, 203], [150, 100, 100], [168, 171, 172], [146, 112, 198], [210, 170, 100], [92, 136, 89], [218, 88, 184], [241, 129, 0], [217, 17, 255], [124, 74, 181], [70, 70, 70], [255, 228, 255], [154, 208, 0], [193, 0, 92], [76, 91, 113], [255, 180, 195], [106, 154, 176], [230, 150, 140], [60, 143, 255], [128, 64, 128], [92, 82, 55], [254, 212, 124], [73, 77, 174], [255, 160, 98], [255, 255, 255], [104, 84, 109], [169, 164, 131], [225, 199, 255], [137, 54, 74], [135, 158, 223], [7, 246, 231], [107, 255, 200], [58, 41, 149], [183, 121, 142], [255, 73, 97], [107, 142, 35], [190, 153, 153], [146, 139, 141], [70, 130, 180], [134, 199, 156], [209, 226, 140], [96, 36, 108], [96, 96, 96], [64, 170, 64], [152, 251, 152], [208, 229, 228], [206, 186, 171], [152, 161, 64], [116, 112, 0], [0, 114, 143], [102, 102, 156], [250, 141, 255]]
# part_palette = [(0,0,0), (128,0,0), (255,0,0), (0,85,0), (170,0,51), (255,85,0), (0,0,85), (0,119,221), (85,85,0), (0,85,85), (85,51,0), (52,86,128), (0,128,0), (0,0,255), (51,170,221), (0,255,255), (85,255,170), (170,255,85), (255,255,0), (255,170,0)]
part_palette = [[0, 0, 0], [220, 20, 60], [119, 11, 32], [0, 0, 142], [0, 0, 230], [106, 0, 228], [0, 60, 100], [0, 80, 100], [0, 0, 70], [0, 0, 192], [250, 170, 30], [100, 170, 30], [220, 220, 0], [175, 116, 175], [250, 0, 30], [165, 42, 42], [255, 77, 255], [0, 226, 252], [182, 182, 255], [0, 82, 0], [120, 166, 157], [110, 76, 0], [174, 57, 255], [199, 100, 0], [72, 0, 118], [255, 179, 240]]
hot_palette = [[0, 0, 0], [220, 20, 60], [119, 11, 32], [0, 0, 142], [0, 0, 230], [106, 0, 228], [0, 60, 100], [0, 80, 100], [0, 0, 70], [0, 0, 192], [250, 170, 30], [100, 170, 30], [220, 220, 0], [175, 116, 175], [250, 0, 30], [165, 42, 42], [255, 77, 255], [0, 226, 252]]

# convert the masks from CHW float in [0, 1] to HWC uint8
sem_mask_gt = np.transpose(sem_mask_gt[0], (1, 2, 0)) * 255
sem_mask_gt = sem_mask_gt.astype(np.uint8)
sem_mask_pred = np.transpose(sem_mask_pred[0], (1, 2, 0)) * 255
sem_mask_pred = sem_mask_pred.astype(np.uint8)
part_mask_gt = np.transpose(part_mask_gt[0], (1, 2, 0)) * 255
part_mask_gt = part_mask_gt.astype(np.uint8)
part_mask_pred = np.transpose(part_mask_pred[0], (1, 2, 0)) * 255
part_mask_pred = part_mask_pred.astype(np.uint8)
contact_2d_gt_rgb = contact_2d_gt_rgb[0] * 255
contact_2d_gt_rgb = contact_2d_gt_rgb.astype(np.uint8)
contact_2d_pred_rgb = contact_2d_pred_rgb[0] * 255
contact_2d_pred_rgb = contact_2d_pred_rgb.astype(np.uint8)

# colorize the channel masks with their palettes
sem_mask_rgb = unsplit(sem_mask_gt, sem_palette)
sem_pred_rgb = unsplit(sem_mask_pred, sem_palette)
part_mask_rgb = unsplit(part_mask_gt, part_palette)
part_pred_rgb = unsplit(part_mask_pred, part_palette)

sem_mask_rgb = sem_mask_rgb.astype(np.uint8)
sem_pred_rgb = sem_pred_rgb.astype(np.uint8)
part_mask_rgb = part_mask_rgb.astype(np.uint8)
part_pred_rgb = part_pred_rgb.astype(np.uint8)

sem_mask_rgb = pil_img.fromarray(sem_mask_rgb)
sem_pred_rgb = pil_img.fromarray(sem_pred_rgb)
part_mask_rgb = pil_img.fromarray(part_mask_rgb)
part_pred_rgb = pil_img.fromarray(part_pred_rgb)
contact_2d_gt_rgb = pil_img.fromarray(contact_2d_gt_rgb)
contact_2d_pred_rgb = pil_img.fromarray(contact_2d_pred_rgb)

# tile the mesh renders and all masks onto a single summary canvas
tot_rend = pil_img.new('RGB', (3000, 2000))
tot_rend.paste(gt_rend, (0, 0))
tot_rend.paste(pred_rend, (0, 450))
tot_rend.paste(sem_mask_rgb, (0, 900))
tot_rend.paste(sem_pred_rgb, (400, 900))
tot_rend.paste(part_mask_rgb, (0, 1300))
tot_rend.paste(part_pred_rgb, (400, 1300))
tot_rend.paste(contact_2d_gt_rgb, (0, 1700))
tot_rend.paste(contact_2d_pred_rgb, (400, 1700))
return tot_rend
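
An aside on the two thresholding loops above: `contact_labels_3d` and `contact_labels_3d_pred` are already NumPy arrays aligned with the mesh vertices, so each per-vertex Python loop can be collapsed into a single boolean-mask assignment. A minimal sketch of that idiom, assuming the same SMPL template mesh is available on disk (the random probabilities below are only a placeholder):

import numpy as np
import trimesh

mesh = trimesh.load('./data/smpl/smpl_neutral_tpose.ply', process=False)
contact_probs = np.random.rand(len(mesh.vertices))  # placeholder per-vertex contact probabilities
th = 0.5                                            # same threshold as above
color = np.array([0, 0, 0, 255], dtype=np.uint8)    # opaque black, as above

# one fancy-indexed assignment replaces `for vid, val in enumerate(...)`
mask = contact_probs >= th
mesh.visual.vertex_colors[mask] = color

This produces the same coloring as the loops while avoiding 6890 Python-level iterations per mesh (the SMPL vertex count).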