import os
import copy
import random
from PIL import Image
import torch
import matplotlib.pyplot as plt
import numpy as np
from plyfile import PlyData
from segment_anything import SamPredictor, sam_model_registry


def get_image_ids(path):
    files = os.listdir(path)
    files = [f.split('.')[0] for f in files if os.path.isfile(path + '/' + f)]  # Filtering only the files.
    return sorted(files)


def load_align_matrix_from_txt(path):
    lines = open(path).readlines()
    # Test-set scenes do not provide an axis-alignment matrix; fall back to identity.
    axis_align_matrix = np.eye(4)
    for line in lines:
        if 'axisAlignment' in line:
            axis_align_matrix = [
                float(x)
                for x in line.rstrip().strip('axisAlignment = ').split(' ')
            ]
            break
    axis_align_matrix = np.array(axis_align_matrix).reshape((4, 4))
    return axis_align_matrix


def load_matrix_from_txt(path, shape=(4, 4)):
    with open(path) as f:
        txt = f.readlines()
    txt = ''.join(txt).replace('\n', ' ')
    matrix = [float(v) for v in txt.split()]
    return np.array(matrix).reshape(shape)


def load_image(path):
    image = Image.open(path)
    return np.array(image)


def convert_from_uvd(u, v, d, intr, pose, align):
    # Back-project pixel (u, v) with raw depth d into world coordinates using the
    # depth intrinsics, the camera-to-world pose and the axis-alignment matrix.
    if d == 0:
        return None, None, None
    fx = intr[0, 0]
    fy = intr[1, 1]
    cx = intr[0, 2]
    cy = intr[1, 2]
    depth_scale = 1000  # ScanNet depth maps store depth in millimetres
    z = d / depth_scale
    x = (u - cx) * z / fx
    y = (v - cy) * z / fy
    world = align @ pose @ np.array([x, y, z, 1])
    return world[:3] / world[3]


# Find the closest points in the point cloud to the selected point.
def find_closest_point(point, point_cloud, num=1):
    # Euclidean distance between the query point and every row of the point cloud.
    distances = np.linalg.norm(point_cloud - point, axis=1)
    # Indices of the `num` rows with the smallest distances.
    closest_index = np.argsort(distances)[:num]
    # The corresponding coordinates.
    closest_vector = point_cloud[closest_index]
    return closest_index, closest_vector
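
# Illustrative usage (hypothetical values, not part of the original script):
#   cloud = np.random.rand(1000, 3)                          # (N, 3) point cloud
#   idx, pts = find_closest_point(np.array([0.5, 0.5, 0.5]), cloud, num=10)
#   # idx holds the indices of the 10 nearest points, pts their coordinates.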


def plot_3d(xdata, ydata, zdata, color=None, b_min=2, b_max=8, view=(45, 45)):
    fig, ax = plt.subplots(subplot_kw={"projection": "3d"}, dpi=200)
    ax.view_init(view[0], view[1])
    ax.set_xlim(b_min, b_max)
    ax.set_ylim(b_min, b_max)
    ax.set_zlim(b_min, b_max)
    # `color` is expected to hold per-point RGB values ('rgb' is not a valid
    # matplotlib colormap name, so no cmap is passed here).
    ax.scatter3D(xdata, ydata, zdata, c=color, s=0.1)


class SAM3DDemo(object):
    def __init__(self, sam_model, sam_ckpt, scene_name):
        sam = sam_model_registry[sam_model](checkpoint=sam_ckpt)
        self.predictor = SamPredictor(sam)
        self.scene_name = scene_name
        scene_path = os.path.join('./scannet_data', scene_name)
        self.color_path = os.path.join(scene_path, 'color')
        self.depth_path = os.path.join(scene_path, 'depth')
        self.pose_path = os.path.join(scene_path, 'pose')
        self.intrinsic_path = os.path.join(scene_path, 'intrinsic')
        self.align_matrix_path = f'{scene_path}/{scene_name}.txt'
        self.img_ids = get_image_ids(self.color_path)
        self.align_matrix = load_align_matrix_from_txt(self.align_matrix_path)
        self.intrinsic_depth = load_matrix_from_txt(os.path.join(self.intrinsic_path, 'intrinsic_depth.txt'))
        self.poses = [load_matrix_from_txt(os.path.join(self.pose_path, f'{i}.txt')) for i in self.img_ids]
        self.rgb_images = [load_image(os.path.join(self.color_path, f'{i}.jpg')) for i in self.img_ids]
        self.depth_images = [load_image(os.path.join(self.depth_path, f'{i}.png')) for i in self.img_ids]

    def project_3D_to_images(self, select_points, valid_margin=20):
        valid_img_ids = []
        valid_points = {}
        for img_i in range(len(self.img_ids)):
            rgb_img = self.rgb_images[img_i]
            depth_img = self.depth_images[img_i]
            extrinsics = self.poses[img_i]
            projection_matrix = self.intrinsic_depth @ np.linalg.inv(self.align_matrix @ extrinsics)
            raw_points = np.vstack((select_points.T, np.ones((1, select_points.T.shape[1]))))
            raw_points = np.dot(projection_matrix, raw_points)
            points = raw_points[:2, :] / raw_points[2, :]
            points = np.round(points).astype(np.int32)
            # Simple bounds check: every projected point must lie in front of the camera
            # and inside the image, at least `valid_margin` pixels away from the border.
            valid = (points[0] >= valid_margin).all() & (points[1] >= valid_margin).all() \
                & (points[0] < (rgb_img.shape[1] - valid_margin)).all() & (points[1] < (rgb_img.shape[0] - valid_margin)).all() \
                & (raw_points[2, :] > 0).all()
            if valid:
                # Occlusion check: the projected depth of the first point must agree with
                # the measured depth within a relative margin.
                depth_margin = 0.4
                gt_depths = depth_img[points[1], points[0]] / 1000
                proj_depths = raw_points[2, :]
                if (proj_depths[0] > (1 - depth_margin / 2.0) * gt_depths[0]) & (proj_depths[0] < (1 + depth_margin / 2.0) * gt_depths[0]):
                    valid_img_ids.append(img_i)
                    valid_points[img_i] = points
        # Visualize the projected points on the last valid frame.
        show_id = valid_img_ids[-1]
        show_points = valid_points[show_id]
        rgb_img = self.rgb_images[show_id]
        fig, ax = plt.subplots()
        ax.imshow(rgb_img)
        for x, y in zip(show_points[0], show_points[1]):
            ax.plot(x, y, 'ro')
        canvas = fig.canvas
        canvas.draw()
        w, h = canvas.get_width_height()
        rgb_img_w_points = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(h, w, 3)
        print("projected 3D points to images successfully...")
        return valid_img_ids, valid_points, rgb_img_w_points

    def process_img_w_sam(self, valid_img_ids, valid_points, granularity):
        mask_colors = []
        for img_i in range(len(self.img_ids)):
            rgb_img = self.rgb_images[img_i]
            # Default color (grey) for pixels not covered by a SAM mask.
            msk_color = np.full(rgb_img.shape, 0.5)
            if img_i in valid_img_ids:
                # Prompt SAM with the first projected point of this frame.
                self.predictor.set_image(rgb_img)
                point_coor = valid_points[img_i].T[0][None]
                masks, _, _ = self.predictor.predict(point_coords=point_coor, point_labels=np.array([1]))
                # fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(10, 5))
                # for i in range(3):
                #     mask_img = masks[i][:, :, None] * rgb_img
                #     axs[i].set_title(f'granularity {i}')
                #     axs[i].imshow(mask_img)
                # Paint the pixels of the mask at the requested granularity in blue.
                m = masks[granularity]
                msk_color[m] = [0, 0, 1.0]
            mask_colors.append(msk_color)
        # Visualize the three mask granularities for the last valid frame.
        show_id = valid_img_ids[-1]
        rgb_img = self.rgb_images[show_id]
        fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(24, 8))
        for i in range(3):
            mask_img = masks[i][:, :, None] * rgb_img
            axs[i].set_title(f'granularity {i}')
            axs[i].imshow(mask_img)
        canvas = fig.canvas
        canvas.draw()
        w, h = canvas.get_width_height()
        rgb_img_w_masks = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(h, w, 3)
        print("processed images with SAM successfully...")
        return mask_colors, rgb_img_w_masks

    def project_mask_to_3d(self, mask_colors, sample_ratio=0.002):
        x_data, y_data, z_data, c_data = [], [], [], []
        for img_i in range(len(self.img_ids)):
            # RGB-D frame, camera pose and per-pixel mask colors for this frame.
            d = self.depth_images[img_i]
            c = self.rgb_images[img_i]
            p = self.poses[img_i]
            msk_color = mask_colors[img_i]
            # Back-project a random subset of depth pixels into 3D and attach the mask color.
            for i in range(d.shape[0]):
                for j in range(d.shape[1]):
                    if random.random() < sample_ratio:
                        x, y, z = convert_from_uvd(j, i, d[i, j], self.intrinsic_depth, p, self.align_matrix)
                        if x is None:
                            continue
                        x_data.append(x)
                        y_data.append(y)
                        z_data.append(z)
                        # Map the depth pixel to the color image resolution.
                        ci = int(i * c.shape[0] / d.shape[0])
                        cj = int(j * c.shape[1] / d.shape[1])
                        c_data.append([msk_color[ci, cj]])
        print("reprojected images to 3D points successfully...")
        return x_data, y_data, z_data, c_data

    def match_projected_point_to_gt_point(self, x_data, y_data, z_data, c_data, gt_coords):
        c_data = torch.tensor(np.concatenate(c_data, axis=0))
        img_coords = np.array([x_data, y_data, z_data], dtype=np.float32).T
        # Quantize both point sets into a 0.2 m voxel grid.
        gt_quant_coords = np.floor_divide(gt_coords, 0.2)
        img_quant_coords = np.floor_divide(img_coords, 0.2)
        # Remove the redundant coords.
        unique_gt_coords, gt_inverse_indices = np.unique(gt_quant_coords, axis=0, return_inverse=True)
        unique_img_coords, img_inverse_indices = np.unique(img_quant_coords, axis=0, return_inverse=True)

        # Match each voxel of gt_coords to the corresponding voxel of img_coords.
        def find_loc(vec):
            obj = np.empty((), dtype=object)
            out = np.where((unique_img_coords == vec).all(1))[0]
            obj[()] = out
            return obj

        gt_2_img_map = np.apply_along_axis(find_loc, 1, unique_gt_coords)
        # Some ground-truth voxels have no projected points; fill them with the most
        # recent non-empty match (a simple nearest-fill interpolation).
        gt_2_img_map_filled = []
        start_id = np.array([0])
        for loc in gt_2_img_map:
            if loc.size == 0:
                loc = start_id
            else:
                start_id = loc
            gt_2_img_map_filled.append(int(loc[0]))
        # Average the colors of all projected points that fall into the same voxel.
        mean_colors = []
        for i in range(unique_img_coords.shape[0]):
            valid_locs = np.where(img_inverse_indices == i)
            mean_f = torch.mean(c_data[valid_locs], axis=0)
            # mean_f, _ = torch.mode(c_data[valid_locs], dim=0)
            mean_colors.append(mean_f.unsqueeze(0))
        mean_colors = torch.cat(mean_colors)
        # Project the averaged colors back onto the ground-truth point cloud.
        img_2_gt_colors = mean_colors[gt_2_img_map_filled]
        projected_gt_colors = img_2_gt_colors[gt_inverse_indices]
        print("converted projected points to GT points successfully...")
        return projected_gt_colors

    def render_point_cloud(self, data, color):
        # Write the [0, 1] float colors into a copy of the PLY vertex data as uint8 RGB.
        data_copy = copy.copy(data)
        uint_color = torch.round(torch.tensor(color) * 255).to(torch.uint8)
        data_copy['red'] = uint_color[:, 0]
        data_copy['green'] = uint_color[:, 1]
        data_copy['blue'] = uint_color[:, 2]
        return data_copy

    def run_with_coord(self, point, granularity):
        x_data, y_data, z_data, c_data = [], [], [], []
        plydata = PlyData.read(f"./scannet_data/{self.scene_name}/{self.scene_name}.ply")
        data = plydata.elements[0].data
        # gt_coords holds the ground-truth point cloud coordinates.
        gt_coords = np.array([data['x'], data['y'], data['z']], dtype=np.float32).T
        gt_color = np.array([data['red'], data['green'], data['blue']], dtype=np.float32).T
        blank_color = np.full(gt_color.shape, 0.5)
        # 1. Select the ground-truth points closest to the query coordinate.
        select_index, select_points = find_closest_point(point, gt_coords, num=10)
        point_select_color = blank_color.copy()
        point_select_color[select_index] = [1.0, 0, 0]
        data_point_select = self.render_point_cloud(data, point_select_color)
        # 2. Project the selected points into the RGB-D frames that observe them.
        valid_img_ids, valid_points, rgb_img_w_points = self.project_3D_to_images(select_points)
        # 3. Run SAM on those frames, prompted with the projected points.
        mask_colors, rgb_img_w_masks = self.process_img_w_sam(valid_img_ids, valid_points, granularity)
        # 4. Lift the 2D masks back into 3D and transfer them onto the ground-truth cloud.
        x_data, y_data, z_data, c_data = self.project_mask_to_3d(mask_colors)
        projected_gt_colors = self.match_projected_point_to_gt_point(x_data, y_data, z_data, c_data, gt_coords)
        data_final = self.render_point_cloud(data, projected_gt_colors)
        return data_point_select, rgb_img_w_points, rgb_img_w_masks, data_final
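

# Example usage (illustrative sketch, not part of the original app): the checkpoint
# path, model type and scene name below are assumptions; adjust them to the SAM
# checkpoint and preprocessed ScanNet scene available in your setup.
if __name__ == '__main__':
    demo = SAM3DDemo(sam_model='vit_h',
                     sam_ckpt='./sam_vit_h_4b8939.pth',   # assumed checkpoint path
                     scene_name='scene0000_00')           # assumed ScanNet scene id
    # Query a 3D coordinate (in the aligned world frame) and a SAM mask granularity (0-2).
    outputs = demo.run_with_coord(np.array([1.0, 1.0, 1.0]), granularity=1)
    data_point_select, rgb_img_w_points, rgb_img_w_masks, data_final = outputs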