# --------------------------------------------------------
# utilities needed for the inference
# --------------------------------------------------------
import numpy as np
import torch
import tqdm
from PIL import Image
from dust3r.utils.device import to_cpu, collate_with_cat
from dust3r.utils.misc import invalid_to_nans
from dust3r.utils.geometry import depthmap_to_pts3d, geotrf
from dust3r.viz import SceneViz, auto_cam_size
from dust3r.utils.image import rgb
def _interleave_imgs(img1, img2):
    """Interleave two view dicts so that paired entries alternate along the batch dim."""
    res = {}
    for key, value1 in img1.items():
        value2 = img2[key]
        if isinstance(value1, torch.Tensor) and value1.ndim == value2.ndim:
            # (B, ...) x 2 -> (2B, ...) with entries alternating 1, 2, 1, 2, ...
            value = torch.stack((value1, value2), dim=1).flatten(0, 1)
        else:
            # lists and other sequences are interleaved element-wise
            value = [x for pair in zip(value1, value2) for x in pair]
        res[key] = value
    return res
def make_batch_symmetric(batch):
    """Symmetrize a (view1, view2) batch: every pair also appears with the views swapped."""
    view1, view2 = batch
    view1, view2 = (_interleave_imgs(view1, view2), _interleave_imgs(view2, view1))
    return view1, view2
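# Illustrative sketch (not part of the original pipeline) of what the
# interleaving above produces; the keys and shapes here are made-up toys.
def _example_interleave():
    a = {'img': torch.zeros(2, 3), 'tag': ['a0', 'a1']}
    b = {'img': torch.ones(2, 3), 'tag': ['b0', 'b1']}
    out = _interleave_imgs(a, b)
    # tensors alternate along the batch dim: rows are a0, b0, a1, b1
    assert out['img'].shape == (4, 3)
    assert out['tag'] == ['a0', 'b0', 'a1', 'b1']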
def mask_to_color(mask):
    """Turn a (H, W) mask tensor into an (H, W, 3) image with the mask in the red channel."""
    colors = np.zeros((*mask.shape, 3))
    colors[:, :, 0] = mask.cpu().detach()  # red channel weighted by the mask
    return colors
def visualize_results_mmask(view1, view2, pred1, pred2, save_dir='./tmp', save_name=None, visualize_type='gt'):
    """Overlay the dynamic masks ('gt' or 'pred') on both views and save them as PNGs."""
    viz1 = SceneViz()
    viz2 = SceneViz()
    viz = [viz1, viz2]
    views = [view1, view2]
    poses = [views[view_idx]['camera_pose'][0] for view_idx in [0, 1]]
    cam_size = max(auto_cam_size(poses), 0.5)
    if visualize_type == 'pred':
        cam_size *= 0.1
        views[0]['pts3d'] = geotrf(poses[0], pred1['pts3d'])  # convert from X_camera1 to X_world
        views[1]['pts3d'] = geotrf(poses[0], pred2['pts3d_in_other_view'])
        mmask = [pred1['dynamic_mask'], pred2['dynamic_mask']]
    else:
        mmask = [view1['dynamic_mask'], view2['dynamic_mask']]
    images = []
    save_paths = []
    for view_idx in [0, 1]:
        pts3d = views[view_idx]['pts3d'][0]
        valid_mask = views[view_idx]['valid_mask'][0]
        colors = rgb(views[view_idx]['img'][0])
        alpha = 0.5  # blend weight between the image and the mask overlay
        mmask_color = mask_to_color(mmask[view_idx][0])
        colors = alpha * colors + (1 - alpha) * mmask_color
        images.append(colors)
        # viz[view_idx].add_pointcloud(pts3d, colors, valid_mask)
        # viz[view_idx].add_camera(pose_c2w=views[view_idx]['camera_pose'][0],
        #                          focal=views[view_idx]['camera_intrinsics'][0, 0],
        #                          color=(255, 0, 0),
        #                          image=colors,
        #                          cam_size=cam_size)
        if save_name is None:
            name = f'{views[0]["dataset"][0]}_{views[0]["label"][0]}_{views[0]["instance"][0]}_{views[1]["instance"][0]}_{visualize_type}_{view_idx}'
        else:
            name = f'{save_name}_{view_idx}'
        # save_path = save_dir + '/' + name + '_mmask.glb'
        # viz[view_idx].save_glb(save_path)
        # save_paths.append(save_path)
        # save the blended image, rescaled from [0, 1] to [0, 255]
        rgb_image = (colors * 255).astype(np.uint8)
        img = Image.fromarray(rgb_image)
        img.save(save_dir + '/' + name + '_mmask.png')
    return images[0], images[1]
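# Quick synthetic check (illustrative only) of the overlay used above: the
# dynamic mask lands in the red channel and is alpha-blended onto the image.
def _example_mask_overlay():
    img = np.full((4, 4, 3), 0.5)  # gray image with values in [0, 1]
    mask = torch.zeros(4, 4)
    mask[1:3, 1:3] = 1.0  # a square "dynamic" region
    blended = 0.5 * img + 0.5 * mask_to_color(mask)
    assert blended[2, 2, 0] == 0.75  # red channel brightened where mask == 1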
def visualize_results(view1, view2, pred1, pred2, save_dir='./tmp', save_name=None, visualize_type='gt'):
    """Save the RGB images of both views as PNGs ('gt' or 'pred' naming) and return them."""
    viz1 = SceneViz()
    viz2 = SceneViz()
    viz = [viz1, viz2]
    views = [view1, view2]
    poses = [views[view_idx]['camera_pose'][0] for view_idx in [0, 1]]
    cam_size = max(auto_cam_size(poses), 0.5)
    if visualize_type == 'pred':
        cam_size *= 0.1
        views[0]['pts3d'] = geotrf(poses[0], pred1['pts3d'])  # convert from X_camera1 to X_world
        views[1]['pts3d'] = geotrf(poses[0], pred2['pts3d_in_other_view'])
    save_paths = []
    images = []
    for view_idx in [0, 1]:
        pts3d = views[view_idx]['pts3d'][0]
        valid_mask = views[view_idx]['valid_mask'][0]
        colors = rgb(views[view_idx]['img'][0])
        images.append(colors)
        # viz[view_idx].add_pointcloud(pts3d, colors, valid_mask)
        # viz[view_idx].add_camera(pose_c2w=views[view_idx]['camera_pose'][0],
        #                          focal=views[view_idx]['camera_intrinsics'][0, 0],
        #                          color=(255, 0, 0),
        #                          image=colors,
        #                          cam_size=cam_size)
        if save_name is None:
            name = f'{views[0]["dataset"][0]}_{views[0]["label"][0]}_{views[0]["instance"][0]}_{views[1]["instance"][0]}_{visualize_type}_{view_idx}'
        else:
            name = f'{save_name}_{view_idx}'
        # save_path = save_dir + '/' + name + '.glb'
        # viz[view_idx].save_glb(save_path)
        # save_paths.append(save_path)
        # save the image, rescaled from [0, 1] to [0, 255]
        rgb_image = (colors * 255).astype(np.uint8)
        img = Image.fromarray(rgb_image)
        img.save(save_dir + '/' + name + '.png')
    return images[0], images[1]
def loss_of_one_batch(batch, model, criterion, device, symmetrize_batch=False, use_amp=False, ret=None):
    """Run the model on one (view1, view2) batch and optionally compute the loss.
    If `ret` is given, return only that entry of the result dict."""
    view1, view2 = batch
    ignore_keys = set(['depthmap', 'dataset', 'label', 'instance', 'idx', 'true_shape', 'rng'])
    for view in batch:
        for name in view.keys():  # e.g. img, pts3d, camera_pose, pseudo_focal, ...
            if name in ignore_keys:
                continue
            view[name] = view[name].to(device, non_blocking=True)
    if symmetrize_batch:
        view1, view2 = make_batch_symmetric(batch)
    with torch.amp.autocast(enabled=bool(use_amp), device_type="cuda"):
        pred1, pred2 = model(view1, view2)
        # the loss is supposed to be symmetric, so it is computed in full precision
        with torch.amp.autocast(enabled=False, device_type="cuda"):
            loss = criterion(view1, view2, pred1, pred2) if criterion is not None else None
    result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2, loss=loss)
    return result[ret] if ret else result
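# Hedged usage sketch for the `ret` plumbing above; ToyModel and toy_criterion
# are hypothetical stand-ins, not the real dust3r model or loss.
def _example_loss_of_one_batch():
    class ToyModel(torch.nn.Module):
        def forward(self, v1, v2):
            return {'pts3d': v1['img'] * 0}, {'pts3d_in_other_view': v2['img'] * 0}
    def toy_criterion(v1, v2, p1, p2):
        return p1['pts3d'].sum() + p2['pts3d_in_other_view'].sum()
    batch = ({'img': torch.rand(1, 3, 4, 4)}, {'img': torch.rand(1, 3, 4, 4)})
    loss = loss_of_one_batch(batch, ToyModel(), toy_criterion, device='cpu', ret='loss')
    assert loss == 0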
@torch.no_grad()
def inference(pairs, model, device, batch_size=8, verbose=True):
if verbose:
print(f'>> Inference with model on {len(pairs)} image pairs')
result = []
# first, check if all images have the same size
    multiple_shapes = not check_if_same_size(pairs)
if multiple_shapes: # force bs=1
batch_size = 1
for i in tqdm.trange(0, len(pairs), batch_size, disable=not verbose):
res = loss_of_one_batch(collate_with_cat(pairs[i:i+batch_size]), model, None, device)
result.append(to_cpu(res))
result = collate_with_cat(result, lists=multiple_shapes)
return result
def check_if_same_size(pairs):
shapes1 = [img1['img'].shape[-2:] for img1, img2 in pairs]
shapes2 = [img2['img'].shape[-2:] for img1, img2 in pairs]
return all(shapes1[0] == s for s in shapes1) and all(shapes2[0] == s for s in shapes2)
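# Hedged end-to-end sketch: `load_images` and `make_pairs` exist in the dust3r
# codebase, but the exact arguments below are illustrative assumptions (and
# running this needs real images plus a trained model), hence left commented.
# from dust3r.utils.image import load_images
# from dust3r.image_pairs import make_pairs
# images = load_images(['frame_0.jpg', 'frame_1.jpg'], size=512)
# pairs = make_pairs(images, scene_graph='complete', symmetrize=True)
# output = inference(pairs, model, device='cuda', batch_size=1)
# view1, pred1 = output['view1'], output['pred1']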
def get_pred_pts3d(gt, pred, use_pose=False):
    """Extract the predicted 3D points, optionally transforming them to the world frame."""
    if 'depth' in pred and 'pseudo_focal' in pred:
        try:
            pp = gt['camera_intrinsics'][..., :2, 2]
        except KeyError:
            pp = None
        pts3d = depthmap_to_pts3d(**pred, pp=pp)
    elif 'pts3d' in pred:
        # pts3d are already expressed in this camera's frame
        pts3d = pred['pts3d']
    elif 'pts3d_in_other_view' in pred:
        # pts3d are expressed in the other camera's frame, already transformed
        assert use_pose is True
        return pred['pts3d_in_other_view']  # return!
    if use_pose:
        camera_pose = pred.get('camera_pose')
        assert camera_pose is not None
        pts3d = geotrf(camera_pose, pts3d)
    return pts3d
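# Illustrative sketch with a hypothetical identity camera pose: with an
# identity cam-to-world transform, geotrf leaves the points unchanged.
def _example_get_pred_pts3d():
    pred = {'pts3d': torch.rand(1, 4, 4, 3),
            'camera_pose': torch.eye(4).unsqueeze(0)}  # (B, 4, 4) identity
    pts_world = get_pred_pts3d({}, pred, use_pose=True)
    assert torch.allclose(pts_world, pred['pts3d'])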
def find_opt_scaling(gt_pts1, gt_pts2, pr_pts1, pr_pts2=None, fit_mode='weiszfeld_stop_grad', valid1=None, valid2=None):
    """Estimate a per-batch scale s such that s * gt best matches the predictions,
    using an averaged, median, or Weiszfeld (IRLS) fit."""
    assert gt_pts1.ndim == pr_pts1.ndim == 4
    assert gt_pts1.shape == pr_pts1.shape
    if gt_pts2 is not None:
        assert gt_pts2.ndim == pr_pts2.ndim == 4
        assert gt_pts2.shape == pr_pts2.shape
    # concatenate the point clouds, with invalid points replaced by NaNs
    nan_gt_pts1 = invalid_to_nans(gt_pts1, valid1).flatten(1, 2)
    nan_gt_pts2 = invalid_to_nans(gt_pts2, valid2).flatten(1, 2) if gt_pts2 is not None else None
    pr_pts1 = invalid_to_nans(pr_pts1, valid1).flatten(1, 2)
    pr_pts2 = invalid_to_nans(pr_pts2, valid2).flatten(1, 2) if pr_pts2 is not None else None
    all_gt = torch.cat((nan_gt_pts1, nan_gt_pts2), dim=1) if gt_pts2 is not None else nan_gt_pts1
    all_pr = torch.cat((pr_pts1, pr_pts2), dim=1) if pr_pts2 is not None else pr_pts1
    dot_gt_pr = (all_pr * all_gt).sum(dim=-1)
    dot_gt_gt = all_gt.square().sum(dim=-1)
    if fit_mode.startswith('avg'):
        # L2 closed form: s = <pr, gt> / <gt, gt>
        scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)
    elif fit_mode.startswith('median'):
        scaling = (dot_gt_pr / dot_gt_gt).nanmedian(dim=1).values
    elif fit_mode.startswith('weiszfeld'):
        # init scaling with the L2 closed form
        scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)
        # iteratively re-weighted least squares (Weiszfeld)
        for _ in range(10):
            # re-weight by the inverse of the residual distance
            dis = (all_pr - scaling.view(-1, 1, 1) * all_gt).norm(dim=-1)
            w = dis.clip_(min=1e-8).reciprocal()
            # update the scaling with the new weights
            scaling = (w * dot_gt_pr).nanmean(dim=1) / (w * dot_gt_gt).nanmean(dim=1)
    else:
        raise ValueError(f'bad {fit_mode=}')
    if fit_mode.endswith('stop_grad'):
        scaling = scaling.detach()
    scaling = scaling.clip(min=1e-3)
    return scaling
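# Small synthetic sanity check (illustrative, not from the original file):
# if the ground truth sits at half the predicted scale, the fit should be ~2.
def _example_find_opt_scaling():
    torch.manual_seed(0)
    pr = torch.rand(1, 8, 8, 3)         # (B, H, W, 3) predicted points
    gt = pr / 2.0                       # ground truth at half the scale
    s = find_opt_scaling(gt, None, pr)  # single view, no validity masks
    assert torch.allclose(s, torch.tensor([2.0]), atol=1e-3)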