import os
import copy
import random
from PIL import Image
import torch
import matplotlib.pyplot as plt
import numpy as np
from plyfile import PlyData
from segment_anything import SamPredictor, sam_model_registry
def get_image_ids(path):
    files = os.listdir(path)
    files = [f.split('.')[0] for f in files if os.path.isfile(os.path.join(path, f))]  # Keep only files, dropping extensions.
    return sorted(files, key=int)  # Sort frame ids numerically, not lexicographically.
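
# Read a ScanNet meta file and return the 4x4 axis-alignment matrix from its
# 'axisAlignment' line (identity for test scenes, which lack one).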
def load_align_matrix_from_txt(path):
    lines = open(path).readlines()
    # Test-set scenes don't have an align matrix; default to identity.
    axis_align_matrix = np.eye(4)
    for line in lines:
        if 'axisAlignment' in line:
            # Parse the 16 floats after 'axisAlignment = '.
            axis_align_matrix = [float(x) for x in line.split('=')[1].split()]
            break
    axis_align_matrix = np.array(axis_align_matrix).reshape((4, 4))
    return axis_align_matrix
def load_matrix_from_txt(path, shape=(4, 4)):
    with open(path) as f:
        txt = f.readlines()
    txt = ''.join(txt).replace('\n', ' ')
    matrix = [float(v) for v in txt.split()]
    return np.array(matrix).reshape(shape)
def load_image(path):
    image = Image.open(path)
    return np.array(image)
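
# Back-project pixel (u, v) with raw depth d into world coordinates: unproject
# through the depth intrinsics, then apply the camera-to-world pose and the
# scene's axis-alignment matrix. ScanNet depth is stored in millimeters.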
def convert_from_uvd(u, v, d, intr, pose, align):
    if d == 0:  # Missing depth reading: nothing to back-project.
        return None, None, None
    fx = intr[0, 0]
    fy = intr[1, 1]
    cx = intr[0, 2]
    cy = intr[1, 2]
    depth_scale = 1000
    z = d / depth_scale
    x = (u - cx) * z / fx
    y = (v - cy) * z / fy
    world = align @ pose @ np.array([x, y, z, 1])
    return world[:3] / world[3]
# Find the num closest points in the cloud to the selected point.
def find_closest_point(point, point_cloud, num=1):
    # Euclidean distance between the query point and every point in the cloud.
    distances = np.linalg.norm(point_cloud - point, axis=1)
    # Indices of the num nearest points.
    closest_index = np.argsort(distances)[:num]
    # Get the closest points from the cloud.
    closest_vector = point_cloud[closest_index]
    return closest_index, closest_vector
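
# Quick 3D scatter plot of a (sampled) point cloud within fixed axis bounds.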
def plot_3d(xdata, ydata, zdata, color=None, b_min=2, b_max=8, view=(45, 45)):
    fig, ax = plt.subplots(subplot_kw={"projection": "3d"}, dpi=200)
    ax.view_init(view[0], view[1])
    ax.set_xlim(b_min, b_max)
    ax.set_ylim(b_min, b_max)
    ax.set_zlim(b_min, b_max)
    # color holds per-point RGB triples, so no colormap is needed
    # ('rgb' is not a valid matplotlib colormap name).
    ax.scatter3D(xdata, ydata, zdata, c=color, s=0.1)
class SAM3DDemo(object):
    def __init__(self, sam_model, sam_ckpt, scene_name):
        sam = sam_model_registry[sam_model](checkpoint=sam_ckpt)
        self.predictor = SamPredictor(sam)
        self.scene_name = scene_name
        scene_path = os.path.join('./scannet_data', scene_name)
        self.color_path = os.path.join(scene_path, 'color')
        self.depth_path = os.path.join(scene_path, 'depth')
        self.pose_path = os.path.join(scene_path, 'pose')
        self.intrinsic_path = os.path.join(scene_path, 'intrinsic')
        self.align_matrix_path = f'{scene_path}/{scene_name}.txt'
        self.img_ids = get_image_ids(self.color_path)
        self.align_matrix = load_align_matrix_from_txt(self.align_matrix_path)
        self.intrinsic_depth = load_matrix_from_txt(os.path.join(self.intrinsic_path, 'intrinsic_depth.txt'))
        self.poses = [load_matrix_from_txt(os.path.join(self.pose_path, f'{i}.txt')) for i in self.img_ids]
        self.rgb_images = [load_image(os.path.join(self.color_path, f'{i}.jpg')) for i in self.img_ids]
        self.depth_images = [load_image(os.path.join(self.depth_path, f'{i}.png')) for i in self.img_ids]
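
    # Project the selected 3D points into every frame; keep frames where all
    # points land inside the image (with a margin), lie in front of the
    # camera, and agree with the frame's recorded depth (i.e. are not occluded).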
    def project_3D_to_images(self, select_points, valid_margin=20):
        valid_img_ids = []
        valid_points = {}
        for img_i in range(len(self.img_ids)):
            rgb_img = self.rgb_images[img_i]
            depth_img = self.depth_images[img_i]
            extrinsics = self.poses[img_i]
            projection_matrix = self.intrinsic_depth @ np.linalg.inv(self.align_matrix @ extrinsics)
            raw_points = np.vstack((select_points.T, np.ones((1, select_points.T.shape[1]))))
            raw_points = np.dot(projection_matrix, raw_points)
            # Perspective division to pixel coordinates, then a simple bounds check.
            points = raw_points[:2, :] / raw_points[2, :]
            points = np.round(points).astype(np.int32)
            valid = (points[0] >= valid_margin).all() & (points[1] >= valid_margin).all() \
                & (points[0] < (rgb_img.shape[1] - valid_margin)).all() & (points[1] < (rgb_img.shape[0] - valid_margin)).all() \
                & (raw_points[2, :] > 0).all()
            if valid:
                # Occlusion check: the projected depth of the first point must
                # agree with the recorded depth to within +/- 20%.
                depth_margin = 0.4
                gt_depths = depth_img[points[1], points[0]] / 1000
                proj_depths = raw_points[2, :]
                if (proj_depths[0] > (1 - depth_margin / 2.0) * gt_depths[0]) & (proj_depths[0] < (1 + depth_margin / 2.0) * gt_depths[0]):
                    valid_img_ids.append(img_i)
                    valid_points[img_i] = points
        # Visualize the projected points on the last valid frame.
        show_id = valid_img_ids[-1]
        show_points = valid_points[show_id]
        rgb_img = self.rgb_images[show_id]
        fig, ax = plt.subplots()
        ax.imshow(rgb_img)
        for x, y in zip(show_points[0], show_points[1]):
            ax.plot(x, y, 'ro')
        canvas = fig.canvas
        canvas.draw()
        # buffer_rgba() replaces the tostring_rgb() API removed in newer matplotlib.
        rgb_img_w_points = np.asarray(canvas.buffer_rgba())[..., :3].copy()
        print("Projected 3D points to images successfully...")
        return valid_img_ids, valid_points, rgb_img_w_points
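
    # Prompt SAM in each valid frame with the first projected point and build
    # per-frame color masks (blue = selected object, gray elsewhere) at the
    # requested mask granularity (0, 1, or 2).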
    def process_img_w_sam(self, valid_img_ids, valid_points, granularity):
        mask_colors = []
        masks = None
        for img_i in range(len(self.img_ids)):
            rgb_img = self.rgb_images[img_i]
            msk_color = np.full(rgb_img.shape, 0.5)
            if img_i in valid_img_ids:
                self.predictor.set_image(rgb_img)
                # Prompt SAM with the first projected point as a foreground click.
                point_coor = valid_points[img_i].T[0][None]
                masks, _, _ = self.predictor.predict(point_coords=point_coor, point_labels=np.array([1]))
                # fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(10, 5))
                # for i in range(3):
                #     mask_img = masks[i][:, :, None] * rgb_img
                #     axs[i].set_title(f'granularity {i}')
                #     axs[i].imshow(mask_img)
                m = masks[granularity]
                msk_color[m] = [0, 0, 1.0]
            mask_colors.append(msk_color)
        # Visualize all three mask granularities for the last valid frame;
        # masks still holds that frame's predictions after the loop.
        show_id = valid_img_ids[-1]
        rgb_img = self.rgb_images[show_id]
        fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(24, 8))
        for i in range(3):
            mask_img = masks[i][:, :, None] * rgb_img
            axs[i].set_title(f'granularity {i}')
            axs[i].imshow(mask_img)
        canvas = fig.canvas
        canvas.draw()
        # buffer_rgba() replaces the tostring_rgb() API removed in newer matplotlib.
        rgb_img_w_masks = np.asarray(canvas.buffer_rgba())[..., :3].copy()
        print("Processed images with SAM successfully...")
        return mask_colors, rgb_img_w_masks
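
    # Back-project a sparse random sample of depth pixels from every frame into
    # world space, carrying each pixel's mask color along, to form a colored
    # point cloud of the SAM masks.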
    def project_mask_to_3d(self, mask_colors, sample_ratio=0.002):
        x_data, y_data, z_data, c_data = [], [], [], []
        for img_i in range(len(self.img_ids)):
            # RGB-D frame and its camera-to-world pose.
            d = self.depth_images[img_i]
            c = self.rgb_images[img_i]
            p = self.poses[img_i]
            msk_color = mask_colors[img_i]
            # Project a random subset of depth pixels into the point space.
            for i in range(d.shape[0]):
                for j in range(d.shape[1]):
                    if random.random() < sample_ratio:
                        x, y, z = convert_from_uvd(j, i, d[i, j], self.intrinsic_depth, p, self.align_matrix)
                        if x is None:
                            continue
                        x_data.append(x)
                        y_data.append(y)
                        z_data.append(z)
                        # Rescale depth-pixel coordinates to the color/mask resolution.
                        ci = int(i * c.shape[0] / d.shape[0])
                        cj = int(j * c.shape[1] / d.shape[1])
                        c_data.append([msk_color[ci, cj]])
        print("Reprojected images to 3D points successfully...")
        return x_data, y_data, z_data, c_data
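
    # Transfer colors from the back-projected points onto the ground-truth scan:
    # quantize both clouds into 0.2 m voxels, match each ground-truth voxel to
    # an image voxel, and copy that voxel's mean color to its ground-truth points.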
    def match_projected_point_to_gt_point(self, x_data, y_data, z_data, c_data, gt_coords):
        c_data = torch.tensor(np.concatenate(c_data, axis=0))
        img_coords = np.array([x_data, y_data, z_data], dtype=np.float32).T
        # Quantize both clouds into 0.2 m voxels.
        gt_quant_coords = np.floor_divide(gt_coords, 0.2)
        img_quant_coords = np.floor_divide(img_coords, 0.2)
        # Remove the redundant coords.
        unique_gt_coords, gt_inverse_indices = np.unique(gt_quant_coords, axis=0, return_inverse=True)
        unique_img_coords, img_inverse_indices = np.unique(img_quant_coords, axis=0, return_inverse=True)
        # Match each voxel in gt_coords to a voxel in img_coords.
        def find_loc(vec):
            obj = np.empty((), dtype=object)
            out = np.where((unique_img_coords == vec).all(1))[0]
            obj[()] = out
            return obj
        gt_2_img_map = np.apply_along_axis(find_loc, 1, unique_gt_coords)
        # Some ground-truth voxels have no projected points; fall back to the
        # most recently matched voxel (a simple nearest-previous interpolation).
        gt_2_img_map_filled = []
        start_id = np.array([0])
        for loc in gt_2_img_map:
            if loc.size == 0:  # Check size, since np.any() would also reject a valid match at index 0.
                loc = start_id
            else:
                start_id = loc
            gt_2_img_map_filled.append(int(loc[0]))
        # Average the colors of all projected points that fall in each voxel.
        mean_colors = []
        for i in range(unique_img_coords.shape[0]):
            valid_locs = np.where(img_inverse_indices == i)
            mean_f = torch.mean(c_data[valid_locs], axis=0)
            # mean_f, _ = torch.mode(c_data[valid_locs], dim=0)
            mean_colors.append(mean_f.unsqueeze(0))
        mean_colors = torch.cat(mean_colors)
        # Project the averaged colors back onto the ground-truth point cloud.
        img_2_gt_colors = mean_colors[gt_2_img_map_filled]
        projected_gt_colors = img_2_gt_colors[gt_inverse_indices]
        print("Matched projected points to GT points successfully...")
        return projected_gt_colors
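
    # Write a [0, 1] float color array into a copy of the PLY vertex data as
    # 8-bit RGB channels.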
    def render_point_cloud(self, data, color):
        data_copy = copy.copy(data)
        # as_tensor accepts both numpy arrays and tensors without re-copying a tensor.
        uint_color = torch.round(torch.as_tensor(color) * 255).to(torch.uint8)
        data_copy['red'] = uint_color[:, 0]
        data_copy['green'] = uint_color[:, 1]
        data_copy['blue'] = uint_color[:, 2]
        return data_copy
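
    # Full pipeline: select the ground-truth points nearest the query coordinate,
    # project them into frames, segment with SAM, lift the masks back to 3D, and
    # recolor the ground-truth cloud with the result.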
    def run_with_coord(self, point, granularity):
        plydata = PlyData.read(f"./scannet_data/{self.scene_name}/{self.scene_name}.ply")
        data = plydata.elements[0].data
        # gt_coords holds the ground-truth point cloud coordinates.
        gt_coords = np.array([data['x'], data['y'], data['z']], dtype=np.float32).T
        gt_color = np.array([data['red'], data['green'], data['blue']], dtype=np.float32).T
        blank_color = np.full(gt_color.shape, 0.5)
        # Highlight the 10 ground-truth points closest to the query coordinate.
        select_index, select_points = find_closest_point(point, gt_coords, num=10)
        point_select_color = blank_color.copy()
        point_select_color[select_index] = [1.0, 0, 0]
        data_point_select = self.render_point_cloud(data, point_select_color)
        valid_img_ids, valid_points, rgb_img_w_points = self.project_3D_to_images(select_points)
        mask_colors, rgb_img_w_masks = self.process_img_w_sam(valid_img_ids, valid_points, granularity)
        x_data, y_data, z_data, c_data = self.project_mask_to_3d(mask_colors)
        projected_gt_colors = self.match_projected_point_to_gt_point(x_data, y_data, z_data, c_data, gt_coords)
        data_final = self.render_point_cloud(data, projected_gt_colors)
        return data_point_select, rgb_img_w_points, rgb_img_w_masks, data_final
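
# A minimal usage sketch. The model type, checkpoint filename, scene name, and
# query coordinate below are illustrative assumptions, not values shipped with
# this demo; point them at your own SAM checkpoint and exported ScanNet scene.
if __name__ == '__main__':
    demo = SAM3DDemo('vit_h', 'sam_vit_h_4b8939.pth', 'scene0000_00')
    # Query a 3D coordinate (in the axis-aligned scan frame) at granularity 2.
    point_sel, img_points, img_masks, final_cloud = demo.run_with_coord(
        np.array([2.0, 2.0, 1.0]), granularity=2)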