alexnasa committed on
Commit
4522fd6
·
verified ·
1 Parent(s): 7a02f69

Update src/pixel3dmm/tracking/tracker.py

Browse files
Files changed (1) hide show
  1. src/pixel3dmm/tracking/tracker.py +15 -45
src/pixel3dmm/tracking/tracker.py CHANGED
@@ -1279,6 +1279,18 @@ class Tracker(object):
1279
  verts_depth=proj_vertices[:, :, 2:3],
1280
  is_viz=True
1281
  )
 
 
 
 
 
 
 
 
 
 
 
 
1282
  mask = (self.parse_mask(ops, batch, visualization=True) > 0).float()
1283
  grabbed_depth = ops['actual_rendered_depth'][0, 0,
1284
  torch.clamp(proj_vertices[0, :, 1].long(), 0, self.config.size-1),
@@ -1289,49 +1301,6 @@ class Tracker(object):
1289
  is_visible_verts_idx = torch.ones_like(is_visible_verts_idx)
1290
 
1291
 
1292
- all_final_views = []
1293
- for b_i in range(bs):
1294
- final_views = []
1295
-
1296
- for views in visualizations:
1297
- row = []
1298
- for view in views:
1299
- if view == View.COLOR_OVERLAY:
1300
- row.append((ops['normal_images'][b_i].cpu().numpy() + 1)/2)
1301
- if view == View.GROUND_TRUTH:
1302
- row.append(images[b_i].cpu().numpy())
1303
- if (view == View.LANDMARKS and not self.no_lm) or is_camera:
1304
- gt_lmks = images[b_i:b_i+1].clone()
1305
- gt_lmks = util.tensor_vis_landmarks(gt_lmks, landmarks[b_i:b_i+1, :, :], color='g')
1306
- gt_lmks = util.tensor_vis_landmarks(gt_lmks, batch['left_iris'][b_i:b_i+1, ...], color='g')
1307
- gt_lmks = util.tensor_vis_landmarks(gt_lmks, batch['right_iris'][b_i:b_i+1, ...], color='g')
1308
- gt_lmks = util.tensor_vis_landmarks(gt_lmks, proj_vertices[b_i:b_i+1, left_iris_flame, ...], color='r')
1309
- gt_lmks = util.tensor_vis_landmarks(gt_lmks, proj_vertices[b_i:b_i+1, right_iris_flame, ...], color='r')
1310
- gt_lmks = util.tensor_vis_landmarks(gt_lmks, lmk68[b_i:b_i+1, :, :], color='r')
1311
- row.append(gt_lmks[0].cpu().numpy())
1312
-
1313
- if True:
1314
- nvd_mask = gaussian_blur(ops['mask_images_rendering'].detach(),
1315
- kernel_size=[self.config.normal_mask_ksize, self.config.normal_mask_ksize],
1316
- sigma=[self.config.normal_mask_ksize, self.config.normal_mask_ksize])
1317
- nvd_mask = (nvd_mask > 0.5).float()
1318
- nvd_mask_clone = nvd_mask.clone()
1319
-
1320
-
1321
- eyebrow_level = torch.min(lmk68[:, :, 1], dim=1).indices
1322
-
1323
- for _i in range(eyebrow_level.shape[0]):
1324
- nvd_mask_clone[_i, :, :eyebrow_level[_i], :] = 0
1325
-
1326
-
1327
- final_views.append(row)
1328
-
1329
-
1330
- # VIDEO
1331
- final_views = util.merge_views(final_views)
1332
- all_final_views.append(final_views)
1333
- final_views = np.concatenate(all_final_views, axis=0)
1334
-
1335
  if outer_iter is None:
1336
  frame_id = str(self.frame).zfill(5)
1337
  else:
@@ -1714,8 +1683,9 @@ class Tracker(object):
1714
  batches = {k: torch.cat([x[k] for x in batches], dim=0) for k in batch.keys()}
1715
  selected_frames = torch.from_numpy(np.array(selected_frames)).long().cuda()
1716
 
1717
- result_rendering = self.render_and_save(batches, visualizations=[[View.SHAPE]],
1718
- frame_dst='/joint_initialization', outer_iter=0, timestep=timestep, is_final=True, selected_frames=selected_frames)
 
1719
  video_frames.append(np.array(result_rendering))
1720
  self.frame += 1
1721
 
 
1279
  verts_depth=proj_vertices[:, :, 2:3],
1280
  is_viz=True
1281
  )
1282
+ # if they asked *only* for the pure shape mask:
1283
+ if visualizations == [[View.SHAPE]]:
1284
+ # ops['normal_images'] is [1,3,H,W] in world‐space ∈ [-1,1]
1285
+ normals = ops['normal_images'][0].cpu().numpy() # [3,H,W]
1286
+ # remap to [0,1]
1287
+ normals = (normals + 1.0) / 2.0
1288
+ # H×W×3
1289
+ normals = np.transpose(normals, (1, 2, 0))
1290
+ # scale to uint8
1291
+ arr = (normals * 255).clip(0,255).astype(np.uint8)
1292
+ return arr
1293
+
1294
  mask = (self.parse_mask(ops, batch, visualization=True) > 0).float()
1295
  grabbed_depth = ops['actual_rendered_depth'][0, 0,
1296
  torch.clamp(proj_vertices[0, :, 1].long(), 0, self.config.size-1),
 
1301
  is_visible_verts_idx = torch.ones_like(is_visible_verts_idx)
1302
 
1303
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1304
  if outer_iter is None:
1305
  frame_id = str(self.frame).zfill(5)
1306
  else:
 
1683
  batches = {k: torch.cat([x[k] for x in batches], dim=0) for k in batch.keys()}
1684
  selected_frames = torch.from_numpy(np.array(selected_frames)).long().cuda()
1685
 
1686
+ result_rendering = self.render_and_save(batch,
1687
+ visualizations=[[View.SHAPE]], # only mesh by default
1688
+ frame_dst='/video', save=True, dump_directly=False, outer_iter=0, timestep=timestep, is_final=True, selected_frames=selected_frames)
1689
  video_frames.append(np.array(result_rendering))
1690
  self.frame += 1
1691