from utils.utils_3d import Network_3D |
from utils.utils_2d import Network_2D, load_yaml |
import time |
import torch |
import os |
import os.path as osp |
import imageio |
import glob |
import open3d as o3d |
import numpy as np |
import math |
from models.Mask3D.mask3d import load_mesh, load_ply |
import colorsys |
from tqdm import tqdm |
def get_iou(masks, masks_batch_size=2):
    """Compute the pairwise IoU matrix of a stack of masks.

    Args:
        masks: Tensor of shape (num_masks, num_points). It is cast to float;
            the intersection/union arithmetic is only a true IoU when the
            values are binary (0/1 or bool).
        masks_batch_size: Number of mask rows whose union against all masks
            is materialised at once. Each chunk allocates a
            (chunk, num_masks, num_points) temporary, so this only trades
            memory for speed and never changes the result.

    Returns:
        Float tensor of shape (num_masks, num_masks). Entries whose union is
        empty (both masks all-zero) are NaN because they are computed as 0/0.

    Raises:
        ValueError: If ``masks_batch_size`` is smaller than 1.
    """
    if masks_batch_size < 1:
        raise ValueError("masks_batch_size must be a positive integer")

    masks = masks.float()
    intersection = torch.einsum('ij,kj -> ik', masks, masks)
    num_masks = masks.shape[0]
    if num_masks == 0:
        # Nothing to concatenate; the (0, 0) intersection is already the answer.
        return intersection

    # A point belongs to the union when the summed masks reach at least 1.
    union = torch.cat([
        ((masks[start:start + masks_batch_size, None, :] + masks[None, :, :]) >= 1).sum(-1)
        for start in range(0, num_masks, masks_batch_size)
    ])
    return torch.div(intersection, union)
def apply_nms(masks, scores, nms_th):
    """Select the proposals that survive mask-level non-maximum suppression.

    Args:
        masks: Tensor of shape (num_points, num_masks); it is transposed
            internally so that each row is one mask.
        scores: Tensor of shape (num_masks,) with one score per mask.
        nms_th: A mask is dropped when its IoU with any higher-scoring mask
            is strictly greater than this threshold.

    Returns:
        CPU int64 tensor holding the indices (into the ORIGINAL, unsorted
        ``scores``/``masks`` order) of the kept masks, listed from highest
        to lowest score.
    """
    masks = masks.permute(1, 0)
    # order[r] is the original index of the mask with rank r (rank 0 = best).
    _, order = torch.sort(scores, descending=True)
    iou = get_iou(masks[order])

    # Column j of the strict upper triangle holds the IoU of rank j with
    # every better-ranked mask. NaN IoUs (two empty masks) compare False.
    # NOTE(review): as in the previous loop, a mask that is itself suppressed
    # still suppresses lower-ranked masks; classic greedy NMS would skip it.
    over_threshold = (iou > nms_th).to(torch.uint8)
    suppressed = torch.triu(over_threshold, diagonal=1).sum(dim=0) > 0

    # Map surviving ranks back to original indices. The previous code looked
    # ranks up in the inverse permutation (original index -> rank), which
    # returned the wrong proposals whenever scores were not already sorted.
    return order[~suppressed.to(order.device)].cpu()
def generate_vibrant_colors(num_colors):
    """Return ``num_colors`` fully saturated RGB colors with evenly spaced hues.

    Args:
        num_colors: Number of colors to generate.

    Returns:
        List of ``(r, g, b)`` float tuples in [0, 1], starting at hue 0 and
        stepping by ``1 / num_colors``. An empty list is returned when
        ``num_colors`` is zero or negative (zero used to raise
        ZeroDivisionError).
    """
    if num_colors <= 0:
        return []
    hue_increment = 1.0 / num_colors
    saturation = 1.0
    value = 1.0
    return [
        colorsys.hsv_to_rgb(i * hue_increment, saturation, value)
        for i in range(num_colors)
    ]
def get_visibility_mat(pred_masks_3d, inside_mask, topk = 15):
    """Mark, for every 3D mask, the frames in which it is most visible.

    Args:
        pred_masks_3d: Tensor of shape (num_masks, num_points), one row per mask.
        inside_mask: Tensor of shape (num_frames, num_points) flagging the
            points counted as visible in each frame.
        topk: Number of frames to keep per mask; clamped to ``num_frames``.

    Returns:
        Bool tensor of shape (num_masks, num_frames) that is True for the
        ``topk`` frames with the largest visible fraction of each mask.
        A mask with no points has a NaN fraction in every frame (0/0), so
        its selected frames are whatever ``torch.topk`` returns for NaNs.
    """
    mask_points = pred_masks_3d.float()
    intersection = torch.einsum("ik, fk -> if", mask_points, inside_mask.float())
    total_point_number = mask_points.sum(dim=-1, keepdim=True)
    visibility_matrix = intersection / total_point_number

    topk = min(topk, visibility_matrix.shape[-1])
    top_frames = torch.topk(visibility_matrix, topk, dim=-1).indices

    # scatter_ keeps the indices on the matrix's own device and also copes
    # with zero masks, unlike a row index built from a Python range.
    visibility_matrix_bool = torch.zeros_like(visibility_matrix, dtype=torch.bool)
    visibility_matrix_bool.scatter_(-1, top_frames, True)
    return visibility_matrix_bool
def compute_iou(box, boxes):
    """Return the IoU of one reference box against a batch of boxes.

    Both inputs use the (x1, y1, x2, y2) corner layout.

    Args:
        box: Tensor of shape (4,).
        boxes: Tensor of shape (N, 4).

    Returns:
        Tensor of shape (N,) with the IoU of ``box`` and each row of ``boxes``.
    """
    assert box.shape == (4,), "Reference box must be of shape (4,)"
    assert boxes.shape[1] == 4, "Boxes must be of shape (N, 4)"

    # Overlap rectangle: max of the top-left corners, min of the bottom-right.
    overlap_top_left = torch.max(box[:2], boxes[:, :2])
    overlap_bottom_right = torch.min(box[2:], boxes[:, 2:])
    overlap_extent = (overlap_bottom_right - overlap_top_left).clamp(0)
    inter_area = overlap_extent[:, 0] * overlap_extent[:, 1]

    reference_extent = box[2:] - box[:2]
    candidate_extent = boxes[:, 2:] - boxes[:, :2]
    reference_area = reference_extent[0] * reference_extent[1]
    candidate_area = candidate_extent[:, 0] * candidate_extent[:, 1]

    return inter_area / (reference_area + candidate_area - inter_area)
class OpenYolo3D():
    """Labels class-agnostic 3D mask proposals using per-frame 2D detections.

    The pipeline (see ``predict``) obtains 3D mask proposals from
    ``Network_3D``, 2D bounding boxes from ``Network_2D``, projects the mesh
    points into every frame with ``WORLD_2_CAM`` and assigns each 3D mask the
    label that its projected points hit most often in the 2D label maps.
    """

    # A frame contributes labels for a mask only when more than this many of
    # the mask's visible points project into it.
    MIN_POINTS_PER_FRAME = 10

    def __init__(self, openyolo3d_config = ""):
        """Load the YAML config and build the 3D and 2D networks from it.

        Args:
            openyolo3d_config: Path passed to ``load_yaml``.
        """
        config = load_yaml(openyolo3d_config)
        self.network_3d = Network_3D(config)
        self.network_2d = Network_2D(config)
        self.openyolo3d_config = config

    def predict(self, path_2_scene_data, depth_scale, processed_scene = None, path_to_3d_masks = None, is_gt=False):
        """Run the full pipeline on one scene directory.

        Args:
            path_2_scene_data: Scene directory handed to ``WORLD_2_CAM``; its
                last path component is used as the scene name.
            depth_scale: Divisor applied to raw depth images by ``WORLD_2_CAM``.
            processed_scene: Optional pre-processed scene given to the 3D
                network instead of the mesh path.
            path_to_3d_masks: Optional directory containing
                ``<scene_name>.pt`` with precomputed proposals; when given,
                the 3D network is skipped.
            is_gt: Forwarded to the labelling step (changes top-k handling).

        Returns:
            ``{scene_name: (masks, classes, scores)}``.
        """
        self.world2cam = WORLD_2_CAM(path_2_scene_data, depth_scale, self.openyolo3d_config)
        self.mesh_projections = self.world2cam.get_mesh_projections()
        # Ratio between depth and color resolution per axis ([0] rows, [1] columns).
        self.scaling_params = [
            self.world2cam.depth_resolution[0] / self.world2cam.image_resolution[0],
            self.world2cam.depth_resolution[1] / self.world2cam.image_resolution[1],
        ]

        # normpath strips a trailing separator, which previously produced an
        # empty scene name with split("/")[-1].
        scene_name = osp.basename(osp.normpath(path_2_scene_data))

        print("[π ACTION] 3D mask proposals computation ...")
        start = time.time()
        if path_to_3d_masks is None:
            scene_input = self.world2cam.mesh if processed_scene is None else processed_scene
            raw_preds = self.network_3d.get_class_agnostic_masks(scene_input)
            # presumably masks are (num_points, num_masks) and scores (num_masks,) — TODO confirm
            masks, scores = raw_preds[0], raw_preds[1]
            network3d_cfg = self.openyolo3d_config["network3d"]
            keep_score = scores >= network3d_cfg["th"]
            keep_nms = apply_nms(masks[:, keep_score].cuda(), scores[keep_score].cuda(), network3d_cfg["nms"])
            keep_score_cpu = keep_score.cpu()
            self.preds_3d = (
                masks.cpu().permute(1, 0)[keep_score_cpu][keep_nms].permute(1, 0),
                scores.cpu()[keep_score_cpu][keep_nms],
            )
        else:
            # NOTE: torch.load unpickles the file; only load mask files from
            # a trusted source.
            self.preds_3d = torch.load(osp.join(path_to_3d_masks, f"{scene_name}.pt"))
        print(f"[π INFO] Elapsed time {(time.time()-start)}")
        print("[π INFO] Proposals computed.")

        print("[π ACTION] 2D Bounding Boxes computation ...")
        start = time.time()
        self.preds_2d = self.network_2d.get_bounding_boxes(self.world2cam.color_paths)
        print(f"[π INFO] Elapsed time {(time.time()-start)}")
        print("[π INFO] Bounding boxes computed.")

        print("[π ACTION] Predicting ...")
        start = time.time()
        prediction = self.label_3d_masks_from_2d_bboxes(scene_name, is_gt)
        print(f"[π INFO] Elapsed time {(time.time()-start)}")
        print("[π INFO] Prediction completed")
        return prediction

    def label_3d_masks_from_2d_bboxes(self, scene_name, is_gt=False):
        """Label the stored 3D proposals and cache the result on the instance.

        Reads ``self.mesh_projections``, ``self.preds_2d`` and
        ``self.preds_3d`` (all set by ``predict``) and stores the outcome in
        ``self.predicted_masks``, ``self.predicated_classes`` and
        ``self.predicated_scores``.

        Returns:
            ``{scene_name: (masks, classes, scores)}``.
        """
        projections_mesh_to_frame, keep_visible_points = self.mesh_projections
        prediction_3d_masks, _ = self.preds_3d

        predicted_masks, predicated_classes, predicated_scores = self.label_3d_masks_from_label_maps(
            prediction_3d_masks.bool(),
            self.preds_2d,
            projections_mesh_to_frame,
            keep_visible_points,
            is_gt,
        )

        self.predicted_masks = predicted_masks
        self.predicated_scores = predicated_scores
        self.predicated_classes = predicated_classes
        return {scene_name: (predicted_masks, predicated_classes, predicated_scores)}

    def label_3d_masks_from_label_maps(self,
                                       prediction_3d_masks,
                                       predictions_2d_bboxes,
                                       projections_mesh_to_frame,
                                       keep_visible_points,
                                       is_gt):
        """Assign a class and a score to every 3D mask from 2D label maps.

        Args:
            prediction_3d_masks: Bool tensor, (num_points, num_masks).
            predictions_2d_bboxes: Mapping whose values are dicts with
                ``"bbox"`` and ``"labels"`` tensors, one entry per frame.
            projections_mesh_to_frame: Per-frame pixel coordinates of every
                mesh point, indexed as [frame, point, (x, y)].
            keep_visible_points: Per-frame visibility flag of every mesh point.
            is_gt: When True, 25 frames per mask are used and the per-image
                top-k re-ranking is skipped.

        Returns:
            ``(masks, classes, scores)`` with masks shaped
            (num_points, num_predictions). With top-k re-ranking enabled the
            scores are moved to CUDA, as before.
        """
        config = self.openyolo3d_config["openyolo3d"]
        num_classes = config["num_classes"]
        topk_per_image = config["topk_per_image"]
        # Decided once up front: the old code consulted the loop variable
        # after the loop, which raised NameError when there were no masks.
        use_distributions = topk_per_image != -1

        # Built before the boxes are read below; construct_label_maps may
        # rescale the stored boxes in place (see the note there).
        label_maps = self.construct_label_maps(predictions_2d_bboxes)

        visibility_matrix = get_visibility_mat(
            prediction_3d_masks.cuda().permute(1, 0),
            keep_visible_points.cuda(),
            topk = 25 if is_gt else config["topk"],
        ).cpu()
        # Frames selected by no mask are dropped from every per-frame array.
        valid_frames = visibility_matrix.sum(dim=0) >= 1

        prediction_3d_masks = prediction_3d_masks.permute(1, 0).cpu()
        prediction_3d_masks_np = prediction_3d_masks.numpy()
        projections_mesh_to_frame = projections_mesh_to_frame[valid_frames].cpu().numpy()
        visibility_matrix = visibility_matrix[:, valid_frames].numpy()
        keep_visible_points = keep_visible_points[valid_frames].cpu().numpy()
        label_maps = label_maps[valid_frames].numpy()
        bounding_boxes_valid = [
            frame_pred
            for frame_idx, frame_pred in enumerate(predictions_2d_bboxes.values())
            if valid_frames[frame_idx]
        ]

        distributions = []
        class_labels = []
        class_probs = []

        for mask_id, mask in enumerate(prediction_3d_masks_np):
            prob_normalizer = 0
            labels_per_frame = []
            iou_vals = []
            for frame_id in np.where(visibility_matrix[mask_id])[0]:
                visible_points_mask = (keep_visible_points[frame_id].squeeze() * mask).astype(bool)
                # Counts every visible point, including those of frames that
                # end up contributing no labels.
                prob_normalizer += visible_points_mask.sum()
                instance_x_y_coords = projections_mesh_to_frame[frame_id][visible_points_mask].astype(np.int64)
                boxes = bounding_boxes_valid[frame_id]["bbox"].long()
                # Previously written as len(coords > 10), which only tested
                # for a non-empty array instead of the intended minimum count.
                if len(boxes) > 0 and len(instance_x_y_coords) > self.MIN_POINTS_PER_FRAME:
                    xs = instance_x_y_coords[:, 0]
                    ys = instance_x_y_coords[:, 1]
                    # Tight 2D box around the projected mask (exclusive max).
                    box = torch.tensor([xs.min(), ys.min(), xs.max() + 1, ys.max() + 1])
                    iou_vals.append(compute_iou(box, boxes).max().item())
                    labels_per_frame.append(label_maps[frame_id, ys, xs])

            labels_distribution = np.concatenate(labels_per_frame) if len(labels_per_frame) > 0 else np.array([-1])
            valid_labels = labels_distribution[labels_distribution != -1]

            distribution = torch.zeros(num_classes) if use_distributions else None
            if valid_labels.size != 0:
                all_labels = torch.from_numpy(valid_labels)
                if distribution is not None:
                    for lb in all_labels.unique():
                        distribution[lb.item()] = (all_labels == lb).sum()
                    distribution = distribution / distribution.max()
                class_label = torch.mode(all_labels).values.item()
                class_prob = (labels_distribution == class_label).sum() / prob_normalizer
            else:
                if distribution is not None:
                    # No label was hit: all mass goes to the last class slot.
                    distribution[-1] = 1.0
                class_label = -1
                class_prob = 0.0

            iou_vals = torch.tensor(iou_vals)
            if (iou_vals != 0).sum():
                iou_prob = iou_vals[iou_vals != 0].mean().item()
            else:
                iou_prob = 0.0

            class_labels.append(class_label)
            class_probs.append(class_prob * iou_prob)
            if distribution is not None:
                distributions.append(distribution)

        pred_classes = torch.tensor(class_labels)
        pred_scores = torch.tensor(class_probs)

        if use_distributions:
            # torch.tensor((0, C)) used to build the two values [0, C]
            # rather than an empty (0, C) tensor.
            distributions = torch.stack(distributions) if len(distributions) > 0 else torch.zeros((0, num_classes))

        if use_distributions and (not is_gt):
            # Rank every (mask, class) pair by its normalised vote and keep
            # the best topk_per_image pairs.
            flat_scores = distributions.reshape(-1)
            _, idx = torch.topk(flat_scores, k=min(topk_per_image, len(flat_scores)), largest=True)
            mask_idx = torch.div(idx, num_classes, rounding_mode="floor")
            pred_classes = idx % num_classes
            pred_scores = flat_scores[idx].cuda()
            prediction_3d_masks = prediction_3d_masks[mask_idx]

        return prediction_3d_masks.permute(1, 0), pred_classes, pred_scores

    def construct_label_maps(self, predictions_2d_bboxes, save_label_map=False):
        """Rasterise the 2D boxes of every frame into an int16 label image.

        Args:
            predictions_2d_bboxes: Mapping whose values are dicts with
                ``"bbox"`` (N, 4) as (x1, y1, x2, y2) and ``"labels"`` (N,).
            save_label_map: Unused; kept for interface compatibility.

        Returns:
            Int16 tensor of shape (num_frames, height, width), -1 where no
            box covers the pixel. Boxes are painted from the largest
            width+height to the smallest, so smaller boxes end up on top.
        """
        label_maps = (torch.ones((len(predictions_2d_bboxes), self.world2cam.height, self.world2cam.width))*-1).type(torch.int16)
        scale_rows, scale_cols = self.scaling_params[0], self.scaling_params[1]

        for frame_id, pred in enumerate(predictions_2d_bboxes.values()):
            # NOTE(review): .long() returns the same tensor when "bbox" is
            # already int64, so the scaling below then also rescales the
            # caller's boxes in place. Left untouched on purpose because the
            # later IoU step reads those same boxes — confirm the intent.
            bboxes = pred["bbox"].long()
            labels = pred["labels"].type(torch.int16)
            for column, scale in ((0, scale_cols), (2, scale_cols), (1, scale_rows), (3, scale_rows)):
                bboxes[:, column] = bboxes[:, column] * scale

            bboxes_weights = (bboxes[:, 2] - bboxes[:, 0]) + (bboxes[:, 3] - bboxes[:, 1])
            sorted_indices = bboxes_weights.sort(descending=True).indices
            bboxes = bboxes[sorted_indices]
            labels = labels[sorted_indices]

            for box_idx, bbox in enumerate(bboxes):
                x_min, y_min, x_max, y_max = bbox.tolist()
                label_maps[frame_id, y_min:y_max, x_min:x_max] = labels[box_idx]

        return label_maps

    def save_output_as_ply(self, save_path, highest_score = True):
        """Write the scene mesh with the best predicted masks painted on it.

        Args:
            save_path: Output file passed to ``o3d.io.write_triangle_mesh``.
            highest_score: When True only predictions reaching the maximum
                score are painted; otherwise every prediction within 0.1 of
                the maximum is painted.
        """
        th = self.predicated_scores.max()
        if not highest_score:
            th = th - 0.1

        mesh = load_mesh(self.world2cam.mesh)
        vertex_colors = np.asarray(mesh.vertex_colors)
        vibrant_colors = generate_vibrant_colors(len(self.predicated_scores[self.predicated_scores >= th]))
        masks_per_prediction = self.predicted_masks.permute(1, 0)

        for i in range(len(self.predicated_classes)):
            if self.predicated_scores[i] < th:
                continue
            if len(vibrant_colors) == 0:
                break
            vertex_colors[masks_per_prediction[i]] = np.array(vibrant_colors.pop())

        mesh.vertex_colors = o3d.utility.Vector3dVector(vertex_colors)
        o3d.io.write_triangle_mesh(save_path, mesh)
class WORLD_2_CAM():
    """Projects the scene mesh into every sampled camera frame.

    The scene directory is expected to contain ``poses/<i>.txt``,
    ``intrinsics.txt``, ``depth/<i>.png``, ``color/<i>.jpg`` and one
    ``*.ply`` mesh; every ``frequency``-th frame index is used.
    """

    # Above this many (frames x points) elements the work is chunked.
    _MAX_DENSE_ELEMENTS = 2000000 * 250
    # Points per chunk when transforming / when dividing by depth.
    _TRANSFORM_CHUNK = 800000
    _PROJECTION_CHUNK = 200000

    def __init__(self, path_2_scene, depth_scale, openyolo3d_config = None):
        """Collect the file paths and resolutions of one scene.

        Args:
            path_2_scene: Scene directory.
            depth_scale: Divisor applied to raw depth images when loading.
            openyolo3d_config: Config dict; reads
                ``["openyolo3d"]["vis_depth_threshold"]`` and
                ``["openyolo3d"]["frequency"]``.
        """
        # Never populated in this class; kept because they were public attributes.
        self.meshes = {}
        self.depth_color_paths = {}

        self.vis_depth_threshold = openyolo3d_config["openyolo3d"]['vis_depth_threshold']
        frequency = openyolo3d_config["openyolo3d"]['frequency']

        path_2_poses = osp.join(path_2_scene, "poses")
        num_frames = len(os.listdir(path_2_poses))
        frame_ids = range(num_frames)[::frequency]

        self.poses = [osp.join(path_2_poses, f"{i}.txt") for i in frame_ids]
        path_2_intrinsics = osp.join(path_2_scene, "intrinsics.txt")
        # One shared intrinsics file, repeated once per sampled frame.
        self.intrinsics = [path_2_intrinsics for _ in frame_ids]
        self.mesh = glob.glob(path_2_scene+"/*.ply")[0]

        path_2_depth = osp.join(path_2_scene, "depth")
        self.depth_maps_paths = [osp.join(path_2_depth, f"{i}.png") for i in frame_ids]
        path_2_color = osp.join(path_2_scene, "color")
        self.color_paths = [osp.join(path_2_color, f"{i}.jpg") for i in frame_ids]

        # Both resolutions are numpy image shapes, i.e. (rows, columns).
        self.image_resolution = imageio.imread(self.color_paths[0]).shape[:2]
        self.depth_resolution = imageio.imread(self.depth_maps_paths[0]).shape
        self.height = self.depth_resolution[0]
        self.width = self.depth_resolution[1]
        self.depth_scale = depth_scale
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @staticmethod
    def load_ply(path_2_mesh):
        """Read a point cloud and return homogeneous coordinates and colors.

        Returns:
            ``(coords, colors)`` where ``coords`` is (N, 4) with a trailing
            column of ones and ``colors`` is the array Open3D reports.
        """
        pcd = o3d.io.read_point_cloud(path_2_mesh)
        points = np.asarray(pcd.points)
        colors = np.asarray(pcd.colors)
        coords = np.concatenate([points, np.ones((points.shape[0], 1))], axis = -1)
        return coords, colors

    def load_depth_maps(self):
        """Load every sampled depth image, divided by ``depth_scale``.

        Returns:
            Tensor of shape (num_frames, height, width) on ``self.device``.
        """
        depth_maps = [
            torch.from_numpy(imageio.imread(depth_path) / self.depth_scale).to(self.device)
            for depth_path in self.depth_maps_paths
        ]
        return torch.stack(depth_maps)

    def adjust_intrinsic(self, intrinsic, original_resolution, new_resolution):
        """Rescale an intrinsic matrix from one image resolution to another.

        Args:
            intrinsic: Numpy intrinsic matrix; returned as-is (not copied)
                when the two resolutions are equal.
            original_resolution: Resolution the intrinsics were calibrated for.
            new_resolution: Target resolution.

        Returns:
            A rescaled copy of ``intrinsic``, or ``intrinsic`` itself.
        """
        if original_resolution == new_resolution:
            return intrinsic

        # NOTE(review): the caller passes numpy shapes (rows, columns), while
        # the [0, *] entries are scaled with index 0 of each resolution.
        # This is only consistent when both images share an aspect ratio —
        # confirm the intended axis order.
        resize_width = int(math.floor(new_resolution[1] * float(
            original_resolution[0]) / float(original_resolution[1])))

        adapted_intrinsic = intrinsic.copy()
        adapted_intrinsic[0, 0] *= float(resize_width) / float(original_resolution[0])
        adapted_intrinsic[1, 1] *= float(new_resolution[1]) / float(original_resolution[1])
        adapted_intrinsic[0, 2] *= float(new_resolution[0] - 1) / float(original_resolution[0] - 1)
        adapted_intrinsic[1, 2] *= float(new_resolution[1] - 1) / float(original_resolution[1] - 1)
        return adapted_intrinsic

    @staticmethod
    def _perspective_divide(camera_points):
        """Turn (..., >=3) camera-space points into truncated (x, y) pixels.

        Points whose depth is exactly zero cannot be projected; they are
        given the pixel (-1, -1), which the image-bounds test rejects.
        """
        depth = camera_points[..., 2]
        has_depth = depth != 0
        safe_depth = torch.where(has_depth, depth, torch.ones_like(depth))
        pixels = torch.stack(
            [camera_points[..., 0] / safe_depth, camera_points[..., 1] / safe_depth],
            dim=-1,
        ).long()
        pixels[~has_depth] = -1
        return pixels

    def get_mesh_projections(self):
        """Project every mesh point into every frame and test its visibility.

        A point counts as visible in a frame when its pixel lies strictly
        inside the depth image (x in (0, width), y in (0, height)) and the
        depth map at that pixel differs from the point's camera depth by at
        most ``vis_depth_threshold``.

        Returns:
            ``(projected_points, inside_mask)``, both on the CPU:
            an int16 tensor (num_frames, num_points, 2) of (x, y) pixels and
            a bool tensor (num_frames, num_points) of visibility flags.
        """
        # Use the device chosen in __init__ (as load_depth_maps does) instead
        # of unconditional .cuda() calls, so the CPU fallback is honoured.
        device = self.device

        points, _ = self.load_ply(self.mesh)
        points = torch.from_numpy(points).to(device)
        intrinsic = self.adjust_intrinsic(np.loadtxt(self.intrinsics[0]), self.image_resolution, self.depth_resolution)
        intrinsics = torch.from_numpy(np.stack([intrinsic for _ in self.poses])).to(device)
        # Pose files are inverted to obtain the world-to-camera transforms.
        extrinsics = torch.linalg.inv(torch.from_numpy(np.stack([np.loadtxt(pose) for pose in self.poses])).to(device))
        world_to_pixel = torch.einsum('bij,bjk -> bik', intrinsics, extrinsics)

        num_points = points.shape[0]
        dense = extrinsics.shape[0] * num_points < self._MAX_DENSE_ELEMENTS

        if dense:
            word2cam_mat = torch.einsum('bij, jk -> bik', world_to_pixel, points.T).permute(0, 2, 1)
        else:
            # Too large for one pass: transform point chunks and park them on the CPU.
            word2cam_mat = torch.cat([
                torch.einsum('bij, jk -> bik', world_to_pixel, points[start:start + self._TRANSFORM_CHUNK].T).permute(0, 2, 1).cpu()
                for start in range(0, num_points, self._TRANSFORM_CHUNK)
            ], dim = 1)

        del intrinsics
        del extrinsics
        del world_to_pixel
        del points
        torch.cuda.empty_cache()

        point_depth = word2cam_mat[:, :, 2].to(device)

        if dense:
            projected_points = self._perspective_divide(word2cam_mat)
        else:
            projected_points = torch.cat([
                self._perspective_divide(word2cam_mat[:, start:start + self._PROJECTION_CHUNK, :3].to(device)).cpu()
                for start in range(0, word2cam_mat.shape[1], self._PROJECTION_CHUNK)
            ], dim = 1)

        pixel_x = projected_points[:, :, 0]
        pixel_y = projected_points[:, :, 1]
        inside_mask = (pixel_x > 0) & (pixel_x < self.width) & (pixel_y > 0) & (pixel_y < self.height)

        depth_maps = self.load_depth_maps()
        for frame_id in range(depth_maps.shape[0]):
            # Clone: the row is overwritten through this very mask below.
            points_in_frame_mask = inside_mask[frame_id].clone()
            points_in_frame = projected_points[frame_id][points_in_frame_mask]
            depth_in_frame = point_depth[frame_id][points_in_frame_mask]
            sensor_depth = depth_maps[frame_id][points_in_frame[:, 1].long(), points_in_frame[:, 0].long()]
            visibility_mask = torch.abs(sensor_depth - depth_in_frame) <= self.vis_depth_threshold
            inside_mask[frame_id][points_in_frame_mask] = visibility_mask.to(inside_mask.device)

        return projected_points.type(torch.int16).cpu(), inside_mask.cpu()