# sam_3d/sam_3d.py
import os
import copy
import random
from PIL import Image
import torch
import matplotlib.pyplot as plt
import numpy as np
from plyfile import PlyData
from segment_anything import SamPredictor, sam_model_registry
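# Demo utilities for segmenting a ScanNet scene point cloud with SAM: a 3D query
# point is projected into the scene's RGB-D frames, SAM segments the prompted
# pixel, and the predicted masks are lifted back onto the ground-truth point cloud.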
def get_image_ids(path):
files = os.listdir(path)
    files = [f.split('.')[0] for f in files if os.path.isfile(os.path.join(path, f))]  # keep only files; the filename stem is the image id
return sorted(files)
def load_align_matrix_from_txt(path):
    with open(path) as f:
        lines = f.readlines()
    # test-set scenes don't provide an axis-align matrix, so default to identity
    axis_align_matrix = np.eye(4)
    for line in lines:
        if 'axisAlignment' in line:
            # parse "axisAlignment = m00 m01 ... m33"; splitting on '=' avoids the
            # character-set semantics of str.strip
            axis_align_matrix = [
                float(x) for x in line.split('=')[1].strip().split()
            ]
            break
axis_align_matrix = np.array(axis_align_matrix).reshape((4, 4))
return axis_align_matrix
def load_matrix_from_txt(path, shape=(4, 4)):
with open(path) as f:
txt = f.readlines()
txt = ''.join(txt).replace('\n', ' ')
matrix = [float(v) for v in txt.split()]
return np.array(matrix).reshape(shape)
def load_image(path):
image = Image.open(path)
return np.array(image)
def convert_from_uvd(u, v, d, intr, pose, align):
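    """Back-project a single depth pixel (u, v, d) into the axis-aligned world frame.

    `intr` is the depth camera intrinsics, `pose` the camera-to-world extrinsics,
    and `align` the scene's axis-alignment matrix. Returns (None, None, None) for
    invalid (zero) depth, otherwise the 3D point as an array of x, y, z.
    """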
    if d == 0:
        return None, None, None
    fx = intr[0, 0]
    fy = intr[1, 1]
    cx = intr[0, 2]
    cy = intr[1, 2]
    depth_scale = 1000  # ScanNet depth is stored in millimeters
    z = d / depth_scale
    x = (u - cx) * z / fx
    y = (v - cy) * z / fy
    # camera -> world -> axis-aligned world (homogeneous coordinates)
    world = align @ pose @ np.array([x, y, z, 1])
    return world[:3] / world[3]
# Find the closest point(s) in the point cloud to the selected point
def find_closest_point(point, point_cloud, num=1):
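    """Return the indices and coordinates of the `num` points in `point_cloud`
    closest (by Euclidean distance) to `point`."""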
# calculate the Euclidean distances between the input vector and each row of the matrix
distances = np.linalg.norm(point_cloud - point, axis=1)
    # indices of the `num` rows with the smallest distances
    closest_index = np.argsort(distances)[:num]
    # gather the corresponding points
    closest_vector = point_cloud[closest_index]
return closest_index, closest_vector
def plot_3d(xdata, ydata, zdata, color=None, b_min=2, b_max=8, view=(45, 45)):
fig, ax = plt.subplots(subplot_kw={"projection": "3d"}, dpi=200)
ax.view_init(view[0], view[1])
ax.set_xlim(b_min, b_max)
ax.set_ylim(b_min, b_max)
ax.set_zlim(b_min, b_max)
    # `color` is expected to carry per-point RGB values, so no colormap is needed
    # ('rgb' is not a valid matplotlib colormap name)
    ax.scatter3D(xdata, ydata, zdata, c=color, s=0.1)
class SAM3DDemo(object):
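    """Interactive demo that segments a ScanNet scene point cloud with SAM.

    Given a 3D query point, the pipeline (see `run_with_coord`) projects the point
    into the RGB-D frames where it is visible, prompts SAM with the projected pixel,
    back-projects the predicted masks into 3D, and transfers the mask colors onto
    the ground-truth point cloud.
    """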
def __init__(self, sam_model, sam_ckpt, scene_name):
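        """Load the SAM checkpoint and the per-frame ScanNet data (color, depth,
        camera poses, depth intrinsics, axis-align matrix) for `scene_name` from
        ./scannet_data/<scene_name>."""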
sam = sam_model_registry[sam_model](checkpoint=sam_ckpt)
self.predictor = SamPredictor(sam)
self.scene_name = scene_name
scene_path = os.path.join('./scannet_data', scene_name)
self.color_path = os.path.join(scene_path, 'color')
self.depth_path = os.path.join(scene_path, 'depth')
self.pose_path = os.path.join(scene_path, 'pose')
self.intrinsic_path = os.path.join(scene_path, 'intrinsic')
        self.align_matrix_path = f'{scene_path}/{scene_name}.txt'
        self.img_ids = get_image_ids(self.color_path)
        self.align_matrix = load_align_matrix_from_txt(self.align_matrix_path)
self.intrinsic_depth = load_matrix_from_txt(os.path.join(self.intrinsic_path, 'intrinsic_depth.txt'))
self.poses = [load_matrix_from_txt(os.path.join(self.pose_path, f'{i}.txt')) for i in self.img_ids]
self.rgb_images = [load_image(os.path.join(self.color_path, f'{i}.jpg')) for i in self.img_ids]
        self.depth_images = [load_image(os.path.join(self.depth_path, f'{i}.png')) for i in self.img_ids]
def project_3D_to_images(self, select_points, valid_margin=20):
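        """Project the selected 3D points into every RGB-D frame and keep the frames
        where all points land inside the image (with `valid_margin` pixels of border),
        lie in front of the camera, and roughly agree with the recorded depth.
        Returns the valid frame indices, the per-frame pixel coordinates, and a
        visualization of the last valid frame."""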
valid_img_ids = []
valid_points = {}
for img_i in range(len(self.img_ids)):
rgb_img = self.rgb_images[img_i]
depth_img = self.depth_images[img_i]
extrinsics = self.poses[img_i]
projection_matrix = self.intrinsic_depth @ np.linalg.inv(self.align_matrix @ extrinsics)
raw_points = np.vstack((select_points.T, np.ones((1, select_points.T.shape[1]))))
raw_points = np.dot(projection_matrix, raw_points)
            # perspective divide to pixel coordinates, then a simple bounds check
points = raw_points[:2, :] / raw_points[2, :]
points = np.round(points).astype(np.int32)
valid = (points[0] >= valid_margin).all() & (points[1] >= valid_margin).all() \
& (points[0] < (rgb_img.shape[1] - valid_margin)).all() & (points[1] < (rgb_img.shape[0] - valid_margin)).all() \
& (raw_points[2, :] > 0).all()
if valid:
depth_margin = 0.4
gt_depths = depth_img[points[1], points[0]] / 1000
proj_depths = raw_points[2, :]
if (proj_depths[0] > (1 - depth_margin / 2.0) * gt_depths[0]) & (proj_depths[0] < (1 + depth_margin / 2.0) * gt_depths[0]):
valid_img_ids.append(img_i)
valid_points[img_i] = points
        if not valid_img_ids:
            raise RuntimeError('the selected 3D points are not visible in any frame')
        show_id = valid_img_ids[-1]
show_points = valid_points[show_id]
rgb_img = self.rgb_images[show_id]
fig, ax = plt.subplots()
ax.imshow(rgb_img)
for x, y in zip(show_points[0], show_points[1]):
ax.plot(x, y, 'ro')
canvas = fig.canvas
canvas.draw()
w, h = canvas.get_width_height()
rgb_img_w_points = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(h, w, 3)
print("projecting 3D point to images successfully...")
return valid_img_ids, valid_points, rgb_img_w_points
def process_img_w_sam(self, valid_img_ids, valid_points, granularity):
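        """Run SAM on every valid frame, prompting it with the first projected point,
        and build a per-frame color map (gray background, blue where the mask at the
        requested `granularity` is active). Also returns a visualization of the three
        mask granularities for the last valid frame."""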
mask_colors = []
for img_i in range(len(self.img_ids)):
rgb_img = self.rgb_images[img_i]
msk_color = np.full(rgb_img.shape, 0.5)
if img_i in valid_img_ids:
self.predictor.set_image(rgb_img)
                point_coor = valid_points[img_i].T[0][None]  # first projected point as the SAM point prompt
masks, _, _ = self.predictor.predict(point_coords=point_coor, point_labels=np.array([1]))
# fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(10, 5))
# for i in range(3):
# mask_img = masks[i][:,:,None] * rgb_img
# axs[i].set_title(f'granularity {i}')
# axs[i].imshow(mask_img)
m = masks[granularity]
msk_color[m] = [0, 0, 1.0]
mask_colors.append(msk_color)
show_id = valid_img_ids[-1]
rgb_img = self.rgb_images[show_id]
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(24, 8))
for i in range(3):
mask_img = masks[i][:,:,None] * rgb_img
axs[i].set_title(f'granularity {i}')
axs[i].imshow(mask_img)
canvas = fig.canvas
canvas.draw()
w, h = canvas.get_width_height()
rgb_img_w_masks = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(h, w, 3)
print("processing images with SAM successfully...")
return mask_colors, rgb_img_w_masks
def project_mask_to_3d(self, mask_colors, sample_ratio=0.002):
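        """Back-project a random subset (`sample_ratio`) of depth pixels from every
        frame into the axis-aligned world frame, carrying the SAM mask color of the
        corresponding RGB pixel. Returns parallel lists of x, y, z coordinates and
        colors."""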
x_data, y_data, z_data, c_data = [], [], [], []
for img_i in range(len(self.img_ids)):
            # RGB-D frame, camera pose, and SAM mask colors for this image
d = self.depth_images[img_i]
c = self.rgb_images[img_i]
p = self.poses[img_i]
msk_color = mask_colors[img_i]
            # back-project a sampled subset of depth pixels into the aligned world frame
for i in range(d.shape[0]):
for j in range(d.shape[1]):
if random.random() < sample_ratio:
x, y, z = convert_from_uvd(j, i, d[i, j], self.intrinsic_depth, p, self.align_matrix)
if x is None:
continue
x_data.append(x)
y_data.append(y)
z_data.append(z)
                        # color and depth frames may differ in resolution; rescale the indices
                        ci = int(i * c.shape[0] / d.shape[0])
                        cj = int(j * c.shape[1] / d.shape[1])
c_data.append([msk_color[ci, cj]])
print("reprojecting images to 3D points successfully...")
return x_data, y_data, z_data, c_data
def match_projected_point_to_gt_point(self, x_data, y_data, z_data, c_data, gt_coords):
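        """Transfer the back-projected colors onto the ground-truth point cloud.

        Both point sets are quantized into 0.2-unit voxels; each ground-truth voxel
        is matched to a projected voxel, projected colors are averaged per voxel, and
        the averaged color is assigned to every ground-truth point in that voxel.
        Unmatched voxels fall back to the most recently matched index."""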
c_data = torch.tensor(np.concatenate(c_data, axis=0))
img_coords = np.array([x_data, y_data, z_data], dtype=np.float32).T
gt_quant_coords = np.floor_divide(gt_coords, 0.2)
img_quant_coords = np.floor_divide(img_coords, 0.2)
        # Remove redundant coords by deduplicating voxel indices
unique_gt_coords, gt_inverse_indices = np.unique(gt_quant_coords, axis=0, return_inverse=True)
unique_img_coords, img_inverse_indices = np.unique(img_quant_coords, axis=0, return_inverse=True)
        # Match each ground-truth voxel to a voxel of the projected coords
def find_loc(vec):
obj = np.empty((), dtype=object)
out = np.where((unique_img_coords == vec).all(1))[0]
obj[()] = out
return obj
gt_2_img_map = np.apply_along_axis(find_loc, 1, unique_gt_coords)
        # Some ground-truth voxels have no projected points; fall back to the most
        # recently matched index (a simple fill with the previous match)
        gt_2_img_map_filled = []
        start_id = np.array([0])
        for loc in gt_2_img_map:
            if loc.size == 0:  # np.any() would also misclassify a valid match at index 0
                loc = start_id
            else:
                start_id = loc
            gt_2_img_map_filled.append(int(loc[0]))
mean_colors = []
for i in range(unique_img_coords.shape[0]):
valid_locs = np.where(img_inverse_indices == i)
mean_f = torch.mean(c_data[valid_locs], axis=0)
# mean_f, _ = torch.mode(c_data[valid_locs], dim=0)
mean_colors.append(mean_f.unsqueeze(0))
mean_colors = torch.cat(mean_colors)
        # Project the averaged colors back onto the ground-truth point cloud
img_2_gt_colors = mean_colors[gt_2_img_map_filled]
projected_gt_colors = img_2_gt_colors[gt_inverse_indices]
print("convert projected points to GT points successfully...")
return projected_gt_colors
def render_point_cloud(self, data, color):
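        """Return a copy of the PLY vertex data with its red/green/blue fields
        replaced by `color` (per-point RGB in [0, 1], converted to uint8)."""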
data_copy = copy.copy(data)
uint_color = torch.round(torch.tensor(color) * 255).to(torch.uint8)
data_copy['red'] = uint_color[:, 0]
data_copy['green'] = uint_color[:, 1]
data_copy['blue'] = uint_color[:, 2]
return data_copy
def run_with_coord(self, point, granularity):
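        """Full pipeline for one 3D query point: select the nearest ground-truth
        points, project them into the RGB-D frames, segment with SAM at the given
        `granularity` (0-2), and transfer the mask back onto the ground-truth point
        cloud. Returns the highlighted-point cloud, two visualization images, and
        the final colored cloud."""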
plydata = PlyData.read(f"./scannet_data/{self.scene_name}/{self.scene_name}.ply")
data = plydata.elements[0].data
        # gt_coords holds the ground-truth point cloud coordinates
gt_coords = np.array([data['x'], data['y'], data['z']], dtype=np.float32).T
gt_color = np.array([data['red'], data['green'], data['blue']], dtype=np.float32).T
blank_color = np.full(gt_color.shape, 0.5)
select_index, select_points = find_closest_point(point, gt_coords, num=10)
point_select_color = blank_color.copy()
point_select_color[select_index] = [1.0, 0, 0]
data_point_select = self.render_point_cloud(data, point_select_color)
valid_img_ids, valid_points, rgb_img_w_points = self.project_3D_to_images(select_points)
mask_colors, rgb_img_w_masks = self.process_img_w_sam(valid_img_ids, valid_points, granularity)
x_data, y_data, z_data, c_data = self.project_mask_to_3d(mask_colors)
projected_gt_colors = self.match_projected_point_to_gt_point(x_data, y_data, z_data, c_data, gt_coords)
data_final = self.render_point_cloud(data, projected_gt_colors)
return data_point_select, rgb_img_w_points, rgb_img_w_masks, data_final
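# Minimal usage sketch (not part of the original demo): the checkpoint path, scene
# name, query point, and granularity below are illustrative assumptions only.
# 'vit_h' is one of the keys in segment_anything's sam_model_registry, and
# ./scannet_data/<scene> must contain the color/depth/pose/intrinsic dumps plus
# <scene>.txt and <scene>.ply expected by SAM3DDemo.__init__.
if __name__ == '__main__':
    demo = SAM3DDemo('vit_h', 'sam_vit_h_4b8939.pth', 'scene0000_00')
    query_point = np.array([2.0, 2.0, 1.0], dtype=np.float32)  # assumed example coordinate
    data_point_select, img_w_points, img_w_masks, data_final = demo.run_with_coord(
        query_point, granularity=2)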