import torch
import numpy as np
import cv2
import glob
from pathlib import Path
from tqdm import tqdm
from dust3r.image_pairs import make_pairs
from dust3r.inference import inference
from dust3r.utils.image import load_images, rgb, enlarge_seg_masks
from copy import deepcopy
from scipy.optimize import minimize
import os
from collections import defaultdict
import dust3r.eval_metadata
from dust3r.eval_metadata import dataset_metadata


def eval_mono_depth_estimation(args, model, device):
    """Run monocular depth evaluation for the dataset named by args.eval_dataset."""
    metadata = dataset_metadata.get(args.eval_dataset)
    if metadata is None:
        raise ValueError(f"Unknown dataset: {args.eval_dataset}")
img_path = metadata.get('img_path')
if 'img_path_func' in metadata:
img_path = metadata['img_path_func'](args)
process_func = metadata.get('process_func')
if process_func is None:
raise ValueError(f"No processing function defined for dataset: {args.eval_dataset}")
for filelist, save_dir in process_func(args, img_path):
Path(save_dir).mkdir(parents=True, exist_ok=True)
eval_mono_depth(args, model, device, filelist, save_dir=save_dir)


def eval_mono_depth(args, model, device, filelist, save_dir=None):
    """Predict a depth map for each image in filelist; optionally save .npy and .png results."""
    model.eval()
    load_img_size = 512
    for file in tqdm(filelist):
        # duplicate the single image to construct a symmetric "image pair"
        file = [file]
        imgs = load_images(file, size=load_img_size, verbose=False, crop=not args.no_crop)
        imgs = [imgs[0], deepcopy(imgs[0])]
        imgs[1]['idx'] = 1
        pairs = make_pairs(imgs, symmetrize=True, prefilter=None)
        output = inference(pairs, model, device, batch_size=1, verbose=False)
        # depth is the z-channel of the predicted pointmap, averaged over the pair
        depth_map = output['pred1']['pts3d'][..., -1].mean(dim=0)
        if save_dir is not None:
            # save the raw depth map as .npy
            stem = os.path.splitext(os.path.basename(file[0]))[0]
            np.save(f"{save_dir}/{stem}depth.npy", depth_map.cpu().numpy())
            # also save a min-max normalized preview as .png
            depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
            depth_map = (depth_map * 255).cpu().numpy().astype(np.uint8)
            cv2.imwrite(f"{save_dir}/{stem}depth.png", depth_map)


# Helpers below are used for calculating the depth evaluation metrics.

def group_by_directory(paths, idx=-1):
    """
    Groups file paths by one of their directory components.

    Parameters:
    - paths (list): List of file paths.
    - idx (int): Index into each path's directory components; the default of -1
      selects the immediate parent directory.

    Returns:
    - dict: A dictionary mapping directory names to lists of file paths.
    """
    grouped_paths = defaultdict(list)
    for path in paths:
        dir_name = os.path.dirname(path).split('/')[idx]
        grouped_paths[dir_name].append(path)
    return grouped_paths
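

# Illustrative sketch with hypothetical paths: frames are grouped by their
# parent directory; with idx=-2 the key would instead be the grandparent
# directory (e.g. 'seq_01' for 'scenes/seq_01/rgb/0000.png').
def _example_group_by_directory():
    paths = ['scenes/seq_01/0000.png', 'scenes/seq_01/0001.png', 'scenes/seq_02/0000.png']
    grouped = group_by_directory(paths)
    assert sorted(grouped) == ['seq_01', 'seq_02'] and len(grouped['seq_01']) == 2
    return grouped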


def depth2disparity(depth, return_mask=False):
    """Invert positive depth values into disparity; non-positive entries stay zero."""
    if isinstance(depth, torch.Tensor):
        disparity = torch.zeros_like(depth)
    elif isinstance(depth, np.ndarray):
        disparity = np.zeros_like(depth)
    positive_mask = depth > 0
    disparity[positive_mask] = 1.0 / depth[positive_mask]
    if return_mask:
        return disparity, positive_mask
    else:
        return disparity
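

# Sanity-check sketch: the reciprocal is its own inverse on positive values,
# which is why depth_evaluation below reuses depth2disparity to turn aligned
# disparity back into depth when disp_input is set.
def _example_depth2disparity_roundtrip():
    depth = torch.tensor([0.0, 0.5, 2.0, 4.0])
    disp, valid = depth2disparity(depth, return_mask=True)  # -> [0.0, 2.0, 0.5, 0.25]
    recovered = depth2disparity(disp)                       # reciprocal once more
    assert torch.allclose(recovered[valid], depth[valid])
    return disp, valid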


def absolute_error_loss(params, predicted_depth, ground_truth_depth):
    """L1 objective for affine alignment: sum of |s * pred + t - gt|."""
    s, t = params
    predicted_aligned = s * predicted_depth + t
    abs_error = np.abs(predicted_aligned - ground_truth_depth)
    return np.sum(abs_error)


def absolute_value_scaling(predicted_depth, ground_truth_depth, s=1, t=0):
    """Fit scale s and shift t minimizing the L1 error via scipy.optimize.minimize."""
    predicted_depth_np = predicted_depth.cpu().numpy().reshape(-1)
    ground_truth_depth_np = ground_truth_depth.cpu().numpy().reshape(-1)
    # cast the initial guesses to floats in case tensors are passed in
    initial_params = [float(s), float(t)]
    result = minimize(absolute_error_loss, initial_params, args=(predicted_depth_np, ground_truth_depth_np))
    s, t = result.x
    return s, t


def absolute_value_scaling2(predicted_depth, ground_truth_depth, s_init=1.0, t_init=0.0, lr=1e-4, max_iters=1000, tol=1e-6):
    """Fit scale s and shift t minimizing the L1 error with Adam, stopping once the loss change falls below tol."""
    # Initialize s and t as torch tensors with requires_grad=True
    s = torch.tensor([s_init], requires_grad=True, device=predicted_depth.device, dtype=predicted_depth.dtype)
    t = torch.tensor([t_init], requires_grad=True, device=predicted_depth.device, dtype=predicted_depth.dtype)
    optimizer = torch.optim.Adam([s, t], lr=lr)
    prev_loss = None
for i in range(max_iters):
optimizer.zero_grad()
# Compute predicted aligned depth
predicted_aligned = s * predicted_depth + t
# Compute absolute error
abs_error = torch.abs(predicted_aligned - ground_truth_depth)
# Compute loss
loss = torch.sum(abs_error)
# Backpropagate
loss.backward()
# Update parameters
optimizer.step()
# Check convergence
if prev_loss is not None and torch.abs(prev_loss - loss) < tol:
break
prev_loss = loss.item()
return s.detach().item(), t.detach().item()
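

# Sanity-check sketch for the two least-absolute-deviation aligners: recover a
# known scale/shift from synthetic data. The Adam variant approaches the same
# optimum iteratively, so its estimate is approximate for a finite lr/max_iters
# budget; the lr and max_iters values here are illustrative.
def _example_lad_alignment():
    torch.manual_seed(0)
    pred = torch.rand(1000) * 10 + 1                 # synthetic prediction in [1, 11)
    gt = 2.5 * pred + 0.3                            # ground truth = known affine map
    s, t = absolute_value_scaling(pred, gt)          # scipy-based L1 fit
    s2, t2 = absolute_value_scaling2(pred, gt, lr=1e-2, max_iters=2000)
    return (s, t), (s2, t2)                          # both should land near (2.5, 0.3)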


def depth_evaluation(predicted_depth_original, ground_truth_depth_original, max_depth=80, custom_mask=None,
                     post_clip_min=None, post_clip_max=None, pre_clip_min=None, pre_clip_max=None,
                     align_with_lstsq=False, align_with_lad=False, align_with_lad2=False, lr=1e-4,
                     max_iters=1000, use_gpu=False, align_with_scale=False, disp_input=False):
    """
    Evaluate a predicted depth map against ground truth and return the standard
    metrics together with a per-pixel relative-error ("parity") map.

    Args:
        predicted_depth_original (numpy.ndarray or torch.Tensor): The predicted depth map.
        ground_truth_depth_original (numpy.ndarray or torch.Tensor): The ground truth depth map.
        max_depth (float): The maximum depth value to consider. Default is 80 meters.
        custom_mask (numpy.ndarray or torch.Tensor): Optional extra validity mask applied after alignment.
        pre_clip_min / pre_clip_max (float): Clamp the prediction before alignment.
        post_clip_min / post_clip_max (float): Clamp the prediction after alignment.
        align_with_lstsq (bool): Align scale and shift with a least-squares (L2) fit.
        align_with_lad (bool): Align scale and shift with an L1 fit via scipy.
        align_with_lad2 (bool): Align scale and shift with an L1 fit via Adam (controlled by lr, max_iters).
        align_with_scale (bool): Align scale only, via Weiszfeld iterations.
        disp_input (bool): Treat the prediction as disparity and align in disparity space.
        use_gpu (bool): Run the computation on CUDA.
        If no alignment flag is set, median scaling is used.

    Returns:
        dict: The evaluation metrics.
        torch.Tensor: The depth error parity map.
        torch.Tensor: The aligned predicted depth map.
        torch.Tensor: The masked ground truth depth map.
    """
if isinstance(predicted_depth_original, np.ndarray):
predicted_depth_original = torch.from_numpy(predicted_depth_original)
if isinstance(ground_truth_depth_original, np.ndarray):
ground_truth_depth_original = torch.from_numpy(ground_truth_depth_original)
if custom_mask is not None and isinstance(custom_mask, np.ndarray):
custom_mask = torch.from_numpy(custom_mask)
# if the dimension is 3, flatten to 2d along the batch dimension
if predicted_depth_original.dim() == 3:
_, h, w = predicted_depth_original.shape
predicted_depth_original = predicted_depth_original.view(-1, w)
ground_truth_depth_original = ground_truth_depth_original.view(-1, w)
if custom_mask is not None:
custom_mask = custom_mask.view(-1, w)
# put to device
if use_gpu:
predicted_depth_original = predicted_depth_original.cuda()
ground_truth_depth_original = ground_truth_depth_original.cuda()
    # Keep pixels whose ground truth is valid (> 0) and, if given, below max_depth
    if max_depth is not None:
        mask = (ground_truth_depth_original > 0) & (ground_truth_depth_original < max_depth)
    else:
        mask = (ground_truth_depth_original > 0)
    predicted_depth = predicted_depth_original[mask]
    ground_truth_depth = ground_truth_depth_original[mask]
    # Clip the predicted depth values before alignment
    if pre_clip_min is not None:
        predicted_depth = torch.clamp(predicted_depth, min=pre_clip_min)
    if pre_clip_max is not None:
        predicted_depth = torch.clamp(predicted_depth, max=pre_clip_max)
    if disp_input:
        # the prediction is disparity: keep the real ground truth depth and
        # convert it to disparity so the alignment happens in disparity space
        real_gt = ground_truth_depth.clone()
        ground_truth_depth = 1 / (ground_truth_depth + 1e-8)
# various alignment methods
if align_with_lstsq:
# Convert to numpy for lstsq
predicted_depth_np = predicted_depth.cpu().numpy().reshape(-1, 1)
ground_truth_depth_np = ground_truth_depth.cpu().numpy().reshape(-1, 1)
# Add a column of ones for the shift term
A = np.hstack([predicted_depth_np, np.ones_like(predicted_depth_np)])
# Solve for scale (s) and shift (t) using least squares
result = np.linalg.lstsq(A, ground_truth_depth_np, rcond=None)
s, t = result[0][0], result[0][1]
# convert to torch tensor
s = torch.tensor(s, device=predicted_depth_original.device)
t = torch.tensor(t, device=predicted_depth_original.device)
# Apply scale and shift
predicted_depth = s * predicted_depth + t
elif align_with_lad:
s, t = absolute_value_scaling(predicted_depth, ground_truth_depth, s=torch.median(ground_truth_depth) / torch.median(predicted_depth))
predicted_depth = s * predicted_depth + t
elif align_with_lad2:
s_init = (torch.median(ground_truth_depth) / torch.median(predicted_depth)).item()
s, t = absolute_value_scaling2(predicted_depth, ground_truth_depth, s_init=s_init, lr=lr, max_iters=max_iters)
predicted_depth = s * predicted_depth + t
elif align_with_scale:
        # Initialize the scale factor 's' with the ratio of means
        mean_gt = torch.nanmean(ground_truth_depth)
        mean_pred = torch.nanmean(predicted_depth)
        s = mean_gt / mean_pred
# Iterative reweighted least squares using the Weiszfeld method
for _ in range(10):
# Compute residuals between scaled predictions and ground truth
residuals = s * predicted_depth - ground_truth_depth
abs_residuals = residuals.abs() + 1e-8 # Add small constant to avoid division by zero
# Compute weights inversely proportional to the residuals
weights = 1.0 / abs_residuals
# Update 's' using weighted sums
weighted_dot_pred_gt = torch.sum(weights * predicted_depth * ground_truth_depth)
weighted_dot_pred_pred = torch.sum(weights * predicted_depth ** 2)
s = weighted_dot_pred_gt / weighted_dot_pred_pred
            # Clip 's' to prevent extreme scaling
            s = s.clamp(min=1e-3)
        # Detach 's' so no gradients flow through it
        s = s.detach()
        # Apply the scale factor to the predicted depth
        predicted_depth = s * predicted_depth
else:
# Align the predicted depth with the ground truth using median scaling
scale_factor = torch.median(ground_truth_depth) / torch.median(predicted_depth)
predicted_depth *= scale_factor
    if disp_input:
        # restore the real ground truth and invert the aligned disparity back to
        # depth (the reciprocal is its own inverse, so depth2disparity doubles as
        # disparity-to-depth)
        ground_truth_depth = real_gt
        predicted_depth = depth2disparity(predicted_depth)
# Clip the predicted depth values
if post_clip_min is not None:
predicted_depth = torch.clamp(predicted_depth, min=post_clip_min)
if post_clip_max is not None:
predicted_depth = torch.clamp(predicted_depth, max=post_clip_max)
    if custom_mask is not None:
        assert custom_mask.shape == ground_truth_depth_original.shape
        # keep the mask on the same device as `mask` so boolean indexing works with use_gpu
        mask_within_mask = custom_mask.to(mask.device)[mask]
        predicted_depth = predicted_depth[mask_within_mask]
        ground_truth_depth = ground_truth_depth[mask_within_mask]
# Calculate the metrics
abs_rel = torch.mean(torch.abs(predicted_depth - ground_truth_depth) / ground_truth_depth).item()
sq_rel = torch.mean(((predicted_depth - ground_truth_depth) ** 2) / ground_truth_depth).item()
    rmse = torch.sqrt(torch.mean((predicted_depth - ground_truth_depth) ** 2)).item()
# Clip the depth values to avoid log(0)
predicted_depth = torch.clamp(predicted_depth, min=1e-5)
log_rmse = torch.sqrt(torch.mean((torch.log(predicted_depth) - torch.log(ground_truth_depth)) ** 2)).item()
# Calculate the accuracy thresholds
max_ratio = torch.maximum(predicted_depth / ground_truth_depth, ground_truth_depth / predicted_depth)
threshold_1 = torch.mean((max_ratio < 1.25).float()).item()
threshold_2 = torch.mean((max_ratio < 1.25 ** 2).float()).item()
threshold_3 = torch.mean((max_ratio < 1.25 ** 3).float()).item()
    # Apply the same alignment to the full map and compute the depth error parity map
    if align_with_lstsq or align_with_lad or align_with_lad2:
        predicted_depth_original = predicted_depth_original * s + t
    elif align_with_scale:
        predicted_depth_original = predicted_depth_original * s
    else:
        predicted_depth_original = predicted_depth_original * scale_factor
    if disp_input:
        predicted_depth_original = depth2disparity(predicted_depth_original)
    depth_error_parity_map = torch.abs(predicted_depth_original - ground_truth_depth_original) / ground_truth_depth_original
    # Zero out invalid pixels in the parity and ground truth maps
    depth_error_parity_map_full = torch.zeros_like(ground_truth_depth_original)
    depth_error_parity_map_full = torch.where(mask, depth_error_parity_map, depth_error_parity_map_full)
    predict_depth_map_full = predicted_depth_original
    gt_depth_map_full = torch.zeros_like(ground_truth_depth_original)
    gt_depth_map_full = torch.where(mask, ground_truth_depth_original, gt_depth_map_full)
    num_valid_pixels = torch.sum(mask).item() if custom_mask is None else torch.sum(mask_within_mask).item()
    if num_valid_pixels == 0:
        # no valid pixels: report zeros instead of the NaNs produced by empty reductions
        abs_rel, sq_rel, rmse, log_rmse, threshold_1, threshold_2, threshold_3 = 0, 0, 0, 0, 0, 0, 0
results = {
'Abs Rel': abs_rel,
'Sq Rel': sq_rel,
'RMSE': rmse,
'Log RMSE': log_rmse,
'δ < 1.25': threshold_1,
'δ < 1.25^2': threshold_2,
'δ < 1.25^3': threshold_3,
'valid_pixels': num_valid_pixels
}
return results, depth_error_parity_map_full, predict_depth_map_full, gt_depth_map_full
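

# End-to-end usage sketch with synthetic maps: the prediction is the ground
# truth under a known affine map plus noise, so after lstsq alignment Abs Rel
# should be near zero and the delta thresholds near one. Shapes, the noise
# level, and the 80 m cap are illustrative assumptions.
def _example_depth_evaluation():
    torch.manual_seed(0)
    gt = torch.rand(2, 64, 64) * 50 + 1                    # valid depths in (1, 51)
    pred = 0.5 * gt + 2.0 + torch.randn_like(gt) * 0.01    # mis-scaled, shifted, noisy
    results, parity_map, pred_full, gt_full = depth_evaluation(
        pred, gt, max_depth=80, align_with_lstsq=True)
    return results  # e.g. results['Abs Rel'] should be close to 0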