import colorsys
import glob
import math
import os
import os.path as osp
import time

import imageio
import numpy as np
import open3d as o3d
import torch
from tqdm import tqdm

from models.Mask3D.mask3d import load_mesh, load_ply
from utils.utils_2d import Network_2D, load_yaml
from utils.utils_3d import Network_3D
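

# Pairwise IoU between binary masks: the intersection matrix is a single
# einsum (pairwise dot products), while the union is computed in row batches
# because the intermediate (batch, N, P) sum tensor is memory-hungry. The
# identity, illustrated on 0/1 float masks:
#
#   m = torch.tensor([[1., 1., 0.], [0., 1., 1.]])
#   inter = torch.einsum('ij,kj -> ik', m, m)               # [[2., 1.], [1., 2.]]
#   union = ((m[:, None, :] + m[None, :, :]) >= 1).sum(-1)  # [[2, 3], [3, 2]]
#   iou = inter / union                                     # [[1., 1/3], [1/3, 1.]]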
def get_iou(masks):
    # masks: (num_masks, num_points) binary masks.
    masks = masks.float()
    intersection = torch.einsum('ij,kj -> ik', masks, masks)
    num_masks = masks.shape[0]

    # Batch the union computation over rows to bound peak memory.
    masks_batch_size = 2
    if masks_batch_size < num_masks:
        ratio = num_masks // masks_batch_size
        remaining = num_masks - ratio * masks_batch_size
        start_masks = list(range(0, ratio * masks_batch_size, masks_batch_size))
        end_masks = list(range(masks_batch_size, (ratio + 1) * masks_batch_size, masks_batch_size))
        if remaining != 0:
            # Fold the leftover masks into the last batch.
            end_masks[-1] = num_masks
    else:
        start_masks = [0]
        end_masks = [num_masks]
    union = torch.cat([((masks[st:ed, None, :] + masks[None, :, :]) >= 1).sum(-1)
                       for st, ed in zip(start_masks, end_masks)])
    iou = torch.div(intersection, union)

    return iou
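

# Greedy NMS over 3D mask proposals: sort by score, let each mask suppress
# lower-scored masks whose IoU with it exceeds `nms_th`, then map the
# surviving ranks back to the caller's original indices. Two quirks worth
# knowing: suppression is encoded by overwriting a rank with 0 (rank 0, the
# top-scoring mask, can never be suppressed, so `unique()` recovers exactly
# the survivors), and already-suppressed masks are not skipped, so they can
# still suppress later ones.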
def apply_nms(masks, scores, nms_th):
    # masks: (num_points, num_masks); scores: (num_masks,).
    masks = masks.permute(1, 0)
    scores_sorted, sorted_scores_indices = torch.sort(scores, descending=True)
    masks_sorted = masks[sorted_scores_indices]
    iou = get_iou(masks_sorted)
    available_indices = torch.arange(len(scores_sorted))
    for indx in range(len(available_indices)):
        # `available_indices` lives on CPU, so bring the suppression indices
        # back from the (possibly CUDA) IoU matrix before indexing.
        remove_indices = torch.where(iou[indx, indx + 1:] > nms_th)[0].cpu()
        available_indices[indx + 1:][remove_indices] = 0
    remaining = available_indices.unique()
    # Map the kept ranks (positions in score-sorted order) back to the
    # caller's original mask indices.
    keep_indices = sorted_scores_indices.cpu()[remaining]
    return keep_indices


def generate_vibrant_colors(num_colors):
    # Evenly spaced hues at full saturation and value give well-separated,
    # vivid RGB colors.
    colors = []
    hue_increment = 1.0 / num_colors
    saturation = 1.0
    value = 1.0

    for i in range(num_colors):
        hue = i * hue_increment
        rgb = colorsys.hsv_to_rgb(hue, saturation, value)
        colors.append(rgb)

    return colors
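

# For each 3D mask, rank the frames by the fraction of the mask's points that
# are visible there (per `inside_mask`), and return a boolean
# (num_masks, num_frames) matrix marking each mask's top-k most-visible frames.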
def get_visibility_mat(pred_masks_3d, inside_mask, topk=15):
    # pred_masks_3d: (num_masks, num_points); inside_mask: (num_frames, num_points).
    # visibility_matrix[i, f] = fraction of mask i's points visible in frame f.
    intersection = torch.einsum("ik, fk -> if", pred_masks_3d.float(), inside_mask.float())
    total_point_number = pred_masks_3d[:, None, :].float().sum(dim=-1)
    visibility_matrix = intersection / total_point_number

    if topk > visibility_matrix.shape[-1]:
        topk = visibility_matrix.shape[-1]

    max_visibility_in_frame = torch.topk(visibility_matrix, topk, dim=-1).indices

    visibility_matrix_bool = torch.zeros_like(visibility_matrix).bool()
    visibility_matrix_bool[torch.arange(len(visibility_matrix_bool))[:, None], max_visibility_in_frame] = True

    return visibility_matrix_bool
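

# One-vs-many axis-aligned box IoU in (x1, y1, x2, y2) format. For example,
# (0, 0, 2, 2) vs (1, 1, 3, 3) overlap in a 1x1 square, so
# IoU = 1 / (4 + 4 - 1) = 1/7.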
def compute_iou(box, boxes):
    assert box.shape == (4,), "Reference box must be of shape (4,)"
    assert boxes.shape[1] == 4, "Boxes must be of shape (N, 4)"

    x1_inter = torch.max(box[0], boxes[:, 0])
    y1_inter = torch.max(box[1], boxes[:, 1])
    x2_inter = torch.min(box[2], boxes[:, 2])
    y2_inter = torch.min(box[3], boxes[:, 3])
    inter_area = (x2_inter - x1_inter).clamp(0) * (y2_inter - y1_inter).clamp(0)
    box_area = (box[2] - box[0]) * (box[3] - box[1])
    boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    union_area = box_area + boxes_area - inter_area
    iou = inter_area / union_area

    return iou
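

# OpenYolo3D couples class-agnostic 3D instance mask proposals (Network_3D)
# with per-frame 2D bounding boxes from an open-vocabulary detector
# (Network_2D): each 3D proposal is labeled by projecting its points into the
# frames where it is most visible and voting over the 2D boxes they land in.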
class OpenYolo3D():
    def __init__(self, openyolo3d_config=""):
        config = load_yaml(openyolo3d_config)
        self.network_3d = Network_3D(config)
        self.network_2d = Network_2D(config)
        self.openyolo3d_config = config

    def predict(self, path_2_scene_data, depth_scale, processed_scene=None, path_to_3d_masks=None, is_gt=False):
        self.world2cam = WORLD_2_CAM(path_2_scene_data, depth_scale, self.openyolo3d_config)
        self.mesh_projections = self.world2cam.get_mesh_projections()
        # Ratio between the depth-map and RGB resolutions, used to rescale 2D boxes.
        self.scaling_params = [self.world2cam.depth_resolution[0] / self.world2cam.image_resolution[0],
                               self.world2cam.depth_resolution[1] / self.world2cam.image_resolution[1]]

        scene_name = path_2_scene_data.split("/")[-1]
        print("[ACTION] 3D mask proposals computation ...")
        start = time.time()

        if path_to_3d_masks is None:
            self.preds_3d = self.network_3d.get_class_agnostic_masks(self.world2cam.mesh) if processed_scene is None else self.network_3d.get_class_agnostic_masks(processed_scene)
            # Drop low-confidence proposals, then remove near-duplicates with NMS.
            keep_score = self.preds_3d[1] >= self.openyolo3d_config["network3d"]["th"]
            keep_nms = apply_nms(self.preds_3d[0][:, keep_score].cuda(), self.preds_3d[1][keep_score].cuda(), self.openyolo3d_config["network3d"]["nms"])
            self.preds_3d = (self.preds_3d[0].cpu().permute(1, 0)[keep_score][keep_nms].permute(1, 0), self.preds_3d[1].cpu()[keep_score][keep_nms])
        else:
            self.preds_3d = torch.load(osp.join(path_to_3d_masks, f"{scene_name}.pt"))

        print(f"[INFO] Elapsed time {time.time() - start:.2f}s")
        print("[INFO] Proposals computed.")

        print("[ACTION] 2D bounding boxes computation ...")
        start = time.time()
        self.preds_2d = self.network_2d.get_bounding_boxes(self.world2cam.color_paths)

        print(f"[INFO] Elapsed time {time.time() - start:.2f}s")
        print("[INFO] Bounding boxes computed.")

        print("[ACTION] Predicting ...")
        start = time.time()
        prediction = self.label_3d_masks_from_2d_bboxes(scene_name, is_gt)
        print(f"[INFO] Elapsed time {time.time() - start:.2f}s")
        print("[INFO] Prediction completed.")

        return prediction
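
    # Labeling pipeline: (1) rasterise each frame's 2D boxes into a per-pixel
    # label map, smallest boxes painted on top; (2) pick each 3D proposal's
    # top-k most-visible frames; (3) project the proposal's visible points into
    # those frames and read off the pixel labels; (4) take the modal label as
    # the class, scored by label agreement weighted by the IoU between the
    # projected points' bounding box and the best-matching 2D detection.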
    def label_3d_masks_from_2d_bboxes(self, scene_name, is_gt=False):
        projections_mesh_to_frame, keep_visible_points = self.mesh_projections
        predictions_2d_bboxes = self.preds_2d
        prediction_3d_masks, _ = self.preds_3d

        predicted_masks, predicted_classes, predicted_scores = self.label_3d_masks_from_label_maps(prediction_3d_masks.bool(),
                                                                                                   predictions_2d_bboxes,
                                                                                                   projections_mesh_to_frame,
                                                                                                   keep_visible_points,
                                                                                                   is_gt)

        self.predicted_masks = predicted_masks
        self.predicted_scores = predicted_scores
        self.predicted_classes = predicted_classes

        return {scene_name: (predicted_masks, predicted_classes, predicted_scores)}
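
    # When `topk_per_image` is enabled, a normalised per-class vote histogram
    # is also kept for every proposal, and the final predictions become the
    # top-k (mask, class) pairs over the flattened instances x classes score
    # matrix, in the spirit of Mask2Former-style inference.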
    def label_3d_masks_from_label_maps(self,
                                       prediction_3d_masks,
                                       predictions_2d_bboxes,
                                       projections_mesh_to_frame,
                                       keep_visible_points,
                                       is_gt):

        label_maps = self.construct_label_maps(predictions_2d_bboxes)

        visibility_matrix = get_visibility_mat(prediction_3d_masks.cuda().permute(1, 0), keep_visible_points.cuda(), topk=25 if is_gt else self.openyolo3d_config["openyolo3d"]["topk"])
        # Keep the frame mask on CPU: it indexes both CPU and CUDA tensors below.
        valid_frames = (visibility_matrix.sum(dim=0) >= 1).cpu()

        prediction_3d_masks = prediction_3d_masks.permute(1, 0).cpu()
        prediction_3d_masks_np = prediction_3d_masks.numpy()
        projections_mesh_to_frame = projections_mesh_to_frame[valid_frames].cpu().numpy()
        visibility_matrix = visibility_matrix[:, valid_frames].cpu().numpy()
        keep_visible_points = keep_visible_points[valid_frames].cpu().numpy()
        distributions = []

        class_labels = []
        class_probs = []
        label_maps = label_maps[valid_frames].numpy()
        bounding_boxes = predictions_2d_bboxes.values()
        bounding_boxes_valid = [bbox for (bi, bbox) in enumerate(bounding_boxes) if valid_frames[bi]]
        for mask_id, mask in enumerate(prediction_3d_masks_np):
            prob_normalizer = 0

            representative_frame_ids = np.where(visibility_matrix[mask_id])[0]
            labels_distribution = []
            iou_vals = []
            for representative_frame_id in representative_frame_ids:
                visible_points_mask = (keep_visible_points[representative_frame_id].squeeze() * mask).astype(bool)
                prob_normalizer += visible_points_mask.sum()
                instance_x_y_coords = projections_mesh_to_frame[representative_frame_id][np.where(visible_points_mask)].astype(np.int64)

                boxes = bounding_boxes_valid[representative_frame_id]["bbox"].long()
                if len(boxes) > 0 and len(instance_x_y_coords) > 10:
                    # Tight 2D box around the projected instance points.
                    x_l, x_r = instance_x_y_coords[:, 0].min(), instance_x_y_coords[:, 0].max() + 1
                    y_t, y_b = instance_x_y_coords[:, 1].min(), instance_x_y_coords[:, 1].max() + 1
                    box = torch.tensor([x_l, y_t, x_r, y_b])

                    iou_values = compute_iou(box, boxes)
                    iou_vals.append(iou_values.max().item())
                    selected_labels = label_maps[representative_frame_id, instance_x_y_coords[:, 1], instance_x_y_coords[:, 0]]
                    labels_distribution.append(selected_labels)

            labels_distribution = np.concatenate(labels_distribution) if len(labels_distribution) > 0 else np.array([-1])

            distribution = torch.zeros(self.openyolo3d_config["openyolo3d"]["num_classes"]) if self.openyolo3d_config["openyolo3d"]["topk_per_image"] != -1 else None
            if (labels_distribution != -1).sum() != 0:
                if distribution is not None:
                    all_labels = torch.from_numpy(labels_distribution[labels_distribution != -1])
                    all_labels_unique = all_labels.unique()
                    for lb in all_labels_unique:
                        distribution[lb] = (all_labels == lb).sum()
                    distribution = distribution / distribution.max()

                # Majority vote across all sampled pixel labels.
                class_label = torch.mode(torch.from_numpy(labels_distribution[labels_distribution != -1])).values.item()
                class_prob = (labels_distribution == class_label).sum() / prob_normalizer
            else:
                if distribution is not None:
                    distribution[-1] = 1.0
                class_label = -1
                class_prob = 0.0

            iou_vals = torch.tensor(iou_vals)

            class_labels.append(class_label)
            if (iou_vals != 0).sum():
                iou_prob = iou_vals[iou_vals != 0].mean().item()
            else:
                iou_prob = 0.0

            class_probs.append(class_prob * iou_prob)
            if distribution is not None:
                distributions.append(distribution)

        pred_classes = torch.tensor(class_labels)
        pred_scores = torch.tensor(class_probs)
        if self.openyolo3d_config["openyolo3d"]["topk_per_image"] != -1:
            distributions = torch.stack(distributions) if len(distributions) > 0 else torch.zeros((0, self.openyolo3d_config["openyolo3d"]["num_classes"]))

            if not is_gt:
                # Top-k over the flattened instances x classes score matrix.
                n_instance = distributions.shape[0]
                distributions = distributions.reshape(-1)
                labels = (
                    torch.arange(self.openyolo3d_config["openyolo3d"]["num_classes"], device=distributions.device)
                    .unsqueeze(0)
                    .repeat(n_instance, 1)
                    .flatten(0, 1)
                )

                cur_topk = self.openyolo3d_config["openyolo3d"]["topk_per_image"]
                _, idx = torch.topk(distributions, k=min(cur_topk, len(distributions)), largest=True)
                mask_idx = torch.div(idx, self.openyolo3d_config["openyolo3d"]["num_classes"], rounding_mode="floor")

                pred_classes = labels[idx]
                pred_scores = distributions[idx].cuda()
                prediction_3d_masks = prediction_3d_masks[mask_idx]

        return prediction_3d_masks.permute(1, 0), pred_classes, pred_scores
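
    # Rasterise the per-frame 2D detections into dense label maps at the depth
    # resolution: every pixel holds the class id of the smallest detection
    # covering it (boxes are painted largest-first so smaller boxes overwrite),
    # and -1 marks pixels covered by no box.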
    def construct_label_maps(self, predictions_2d_bboxes, save_label_map=False):
        label_maps = (torch.ones((len(predictions_2d_bboxes), self.world2cam.height, self.world2cam.width)) * -1).type(torch.int16)
        for frame_id, pred in enumerate(predictions_2d_bboxes.values()):
            bboxes = pred["bbox"].long()
            labels = pred["labels"].type(torch.int16)

            # Rescale boxes from the RGB resolution to the depth-map resolution.
            bboxes[:, 0] = bboxes[:, 0] * self.scaling_params[1]
            bboxes[:, 2] = bboxes[:, 2] * self.scaling_params[1]
            bboxes[:, 1] = bboxes[:, 1] * self.scaling_params[0]
            bboxes[:, 3] = bboxes[:, 3] * self.scaling_params[0]
            # Paint bigger boxes first so smaller boxes end up on top.
            bboxes_weights = (bboxes[:, 2] - bboxes[:, 0]) + (bboxes[:, 3] - bboxes[:, 1])
            sorted_indices = bboxes_weights.sort(descending=True).indices
            bboxes = bboxes[sorted_indices]
            labels = labels[sorted_indices]
            for id, bbox in enumerate(bboxes):
                label_maps[frame_id, bbox[1]:bbox[3], bbox[0]:bbox[2]] = labels[id]

        return label_maps

    def save_output_as_ply(self, save_path, highest_score=True):
        # Keep either only the best-scoring instance, or every instance within
        # 0.1 of the best score.
        if highest_score:
            th = self.predicted_scores.max()
        else:
            th = self.predicted_scores.max() - 0.1

        mesh = load_mesh(self.world2cam.mesh)
        vertex_colors = np.asarray(mesh.vertex_colors)
        vibrant_colors = generate_vibrant_colors(len(self.predicted_scores[self.predicted_scores >= th]))
        for i, class_id in enumerate(self.predicted_classes):
            if self.predicted_scores[i] < th:
                continue
            if len(vibrant_colors) == 0:
                break
            mask = self.predicted_masks.permute(1, 0)[i]
            vertex_colors[mask] = np.array(vibrant_colors.pop())
        mesh.vertex_colors = o3d.utility.Vector3dVector(vertex_colors)
        o3d.io.write_triangle_mesh(save_path, mesh)
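

# WORLD_2_CAM loads a scene's sensor stream and projects every mesh vertex
# into every selected frame; a vertex counts as visible in a frame when it
# projects inside the image and its projected depth agrees with the sensor
# depth map within `vis_depth_threshold`. The expected scene layout, as built
# below, is:
#
#   <scene>/poses/<i>.txt     per-frame camera-to-world poses
#   <scene>/intrinsics.txt    shared pinhole intrinsics
#   <scene>/depth/<i>.png     depth maps
#   <scene>/color/<i>.jpg     RGB frames
#   <scene>/*.ply             the reconstructed mesh / point cloud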
class WORLD_2_CAM():
    def __init__(self, path_2_scene, depth_scale, openyolo3d_config=None):
        self.vis_depth_threshold = openyolo3d_config["openyolo3d"]['vis_depth_threshold']

        # Subsample the frame stream: keep every `frequency`-th frame.
        frequency = openyolo3d_config["openyolo3d"]['frequency']

        path_2_poses = osp.join(path_2_scene, "poses")
        num_frames = len(os.listdir(path_2_poses))
        self.poses = [osp.join(path_2_poses, f"{i}.txt") for i in list(range(num_frames))[::frequency]]

        path_2_intrinsics = osp.join(path_2_scene, "intrinsics.txt")
        self.intrinsics = [path_2_intrinsics for i in list(range(num_frames))[::frequency]]

        self.mesh = glob.glob(path_2_scene + "/*.ply")[0]

        path_2_depth = osp.join(path_2_scene, "depth")
        self.depth_maps_paths = [osp.join(path_2_depth, f"{i}.png") for i in list(range(num_frames))[::frequency]]

        path_2_color = osp.join(path_2_scene, "color")
        self.color_paths = [osp.join(path_2_color, f"{i}.jpg") for i in list(range(num_frames))[::frequency]]

        self.image_resolution = imageio.imread(self.color_paths[0]).shape[:2]
        self.depth_resolution = imageio.imread(self.depth_maps_paths[0]).shape
        self.height = self.depth_resolution[0]
        self.width = self.depth_resolution[1]

        self.depth_scale = depth_scale
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @staticmethod
    def load_ply(path_2_mesh):
        pcd = o3d.io.read_point_cloud(path_2_mesh)
        points = np.asarray(pcd.points)
        colors = np.asarray(pcd.colors)

        # Homogeneous coordinates (N, 4), so one 4x4 matrix maps world -> camera.
        coords = np.concatenate([points, np.ones((points.shape[0], 1))], axis=-1)
        return coords, colors

    def load_depth_maps(self):
        depth_maps = []
        for depth_map_path in self.depth_maps_paths:
            depth_maps.append(torch.from_numpy(imageio.imread(depth_map_path) / self.depth_scale).to(self.device))
        return torch.stack(depth_maps)

    def adjust_intrinsic(self, intrinsic, original_resolution, new_resolution):
        # Adapt the pinhole intrinsics when the RGB and depth resolutions differ.
        if original_resolution == new_resolution:
            return intrinsic

        resize_width = int(math.floor(new_resolution[1] * float(
            original_resolution[0]) / float(original_resolution[1])))

        adapted_intrinsic = intrinsic.copy()
        adapted_intrinsic[0, 0] *= float(resize_width) / float(original_resolution[0])
        adapted_intrinsic[1, 1] *= float(new_resolution[1]) / float(original_resolution[1])
        adapted_intrinsic[0, 2] *= float(new_resolution[0] - 1) / float(original_resolution[0] - 1)
        adapted_intrinsic[1, 2] *= float(new_resolution[1] - 1) / float(original_resolution[1] - 1)
        return adapted_intrinsic
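
    # Vertex projection: p_cam = K @ T_w2c @ p_world on homogeneous points,
    # then pixel = (x / z, y / z). The einsums batch this over all frames; when
    # frames x points is very large, the work is chunked and staged through CPU
    # memory. The `inside_mask` from the image-bounds check is then refined by
    # the depth-visibility test in the final loop.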
    def get_mesh_projections(self):
        # Above this many frame-point pairs, fall back to batched computation.
        N_Large = 2000000 * 250

        points, colors = self.load_ply(self.mesh)
        points, colors = torch.from_numpy(points).cuda(), torch.from_numpy(colors).cuda()

        intrinsic = self.adjust_intrinsic(np.loadtxt(self.intrinsics[0]), self.image_resolution, self.depth_resolution)
        intrinsics = torch.from_numpy(np.stack([intrinsic for frame_id in range(len(self.poses))])).cuda()
        # Poses are camera-to-world; invert to get world-to-camera extrinsics.
        extrinsics = torch.linalg.inv(torch.from_numpy(np.stack([np.loadtxt(pose) for pose in self.poses])).cuda())

        if extrinsics.shape[0] * points.shape[0] < N_Large:
            world2cam_mat = torch.einsum('bij, jk -> bik', torch.einsum('bij,bjk -> bik', intrinsics, extrinsics), points.T).permute(0, 2, 1)
        else:
            B_size = 800000
            Num_Points = points.shape[0]
            Num_batches = Num_Points // B_size + 1
            world2cam_mat = []
            for b_i in range(Num_batches):
                dim_start = b_i * B_size
                dim_last = (b_i + 1) * B_size if b_i != Num_batches - 1 else points.shape[0]
                world2cam_mat_i = torch.einsum('bij, jk -> bik', torch.einsum('bij,bjk -> bik', intrinsics, extrinsics), points[dim_start:dim_last].T).permute(0, 2, 1)
                world2cam_mat.append(world2cam_mat_i.cpu())
            world2cam_mat = torch.cat(world2cam_mat, dim=1)
        del intrinsics
        del extrinsics
        del points
        del colors
        torch.cuda.empty_cache()

        point_depth = world2cam_mat[:, :, 2].cuda()
        if world2cam_mat.shape[1] * world2cam_mat.shape[0] < N_Large:
            size = (world2cam_mat.shape[0], world2cam_mat.shape[1])
            # Perspective divide. The masked reshape assumes every point has a
            # nonzero projected depth; a zero depth would make it fail loudly.
            mask = (world2cam_mat[:, :, 2] != 0).reshape(size[0] * size[1])

            projected_points = torch.stack([(world2cam_mat[:, :, 0].reshape(size[0] * size[1])[mask] / world2cam_mat[:, :, 2].reshape(size[0] * size[1])[mask]).reshape(size),
                                            (world2cam_mat[:, :, 1].reshape(size[0] * size[1])[mask] / world2cam_mat[:, :, 2].reshape(size[0] * size[1])[mask]).reshape(size)]).permute(1, 2, 0).long()
            inside_mask = ((projected_points[:, :, 0] < self.width) * (projected_points[:, :, 0] > 0) * (projected_points[:, :, 1] < self.height) * (projected_points[:, :, 1] > 0) == 1)
        else:
            B_size = 200000
            Num_Points = world2cam_mat.shape[1]
            Num_batches = Num_Points // B_size + 1
            projected_points = []

            for b_i in range(Num_batches):
                dim_start = b_i * B_size
                dim_last = (b_i + 1) * B_size if b_i != Num_batches - 1 else world2cam_mat.shape[1]
                batch_z = world2cam_mat[:, dim_start:dim_last, 2].cuda()
                batch_y = world2cam_mat[:, dim_start:dim_last, 1].cuda()
                batch_x = world2cam_mat[:, dim_start:dim_last, 0].cuda()

                size = (world2cam_mat.shape[0], dim_last - dim_start)
                mask = (batch_z != 0).reshape(size[0] * size[1])
                projected_points_i = torch.stack([(torch.div(batch_x.reshape(size[0] * size[1])[mask], batch_z.reshape(size[0] * size[1])[mask])).reshape(size),
                                                  (torch.div(batch_y.reshape(size[0] * size[1])[mask], batch_z.reshape(size[0] * size[1])[mask])).reshape(size)]).permute(1, 2, 0).long()
                projected_points.append(projected_points_i.cpu())

            projected_points = torch.cat(projected_points, dim=1)
            inside_mask = ((projected_points[:, :, 0] < self.width) * (projected_points[:, :, 0] > 0) * (projected_points[:, :, 1] < self.height) * (projected_points[:, :, 1] > 0) == 1)

        depth_maps = self.load_depth_maps()
        num_frames = depth_maps.shape[0]

        # Refine in-frame points with the depth test: keep a point only when
        # its projected depth matches the sensor depth at that pixel.
        for frame_id in range(num_frames):
            points_in_frame_mask = inside_mask[frame_id].clone()
            points_in_frame = projected_points[frame_id][points_in_frame_mask]
            depth_in_frame = point_depth[frame_id][points_in_frame_mask]
            visibility_mask = (torch.abs(depth_maps[frame_id][points_in_frame[:, 1].long(), points_in_frame[:, 0].long()]
                                         - depth_in_frame) <= self.vis_depth_threshold)

            inside_mask[frame_id][points_in_frame_mask] = visibility_mask.to(inside_mask.device)

        return projected_points.type(torch.int16).cpu(), inside_mask.cpu()
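

# Minimal usage sketch (illustrative only: the config path, scene folder, and
# the depth scale of 1000.0 for millimetre PNG depth maps are assumptions, not
# part of this module):
#
#   model = OpenYolo3D("config.yaml")
#   prediction = model.predict("data/scene0011_00", depth_scale=1000.0)
#   model.save_output_as_ply("output/scene0011_00.ply", highest_score=False)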