# Point-SAM / app.py
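"""Flask server exposing a Point-SAM interactive segmentation API.

The server loads a Point-SAM checkpoint once at startup, caches per-user
point-cloud embeddings in memory, and serves mask predictions for point
prompts posted by the static web front-end.
"""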
import argparse
import dataclasses

import numpy as np
import torch
from flask import Flask, jsonify, request
from flask_cors import CORS

from point_sam import build_point_sam
app = Flask(__name__, static_folder="static")
CORS(app)
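# In-memory session store: each call to /set_pointcloud claims the next slot and
# returns its index as `user_id`; /set_prompts looks the cached embeddings up by
# that id. Slots are reused cyclically once MAX_POINT_ID clouds have been registered.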
MAX_POINT_ID = 100
point_info_id = 0
point_info_list = [None for _ in range(MAX_POINT_ID)]
@dataclasses.dataclass
class AuxInputs:
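    """Auxiliary point-cloud tensors forwarded to the mask decoder."""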
coords: torch.Tensor
features: torch.Tensor
centers: torch.Tensor
interp_index: torch.Tensor = None
interp_weight: torch.Tensor = None
def repeat_interleave(x: torch.Tensor, repeats: int, dim: int):
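    """Repeat each entry of `x` `repeats` times along dimension `dim`."""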
if repeats == 1:
return x
    shape = list(x.shape)
    shape.insert(dim + 1, repeats)
    x = x.unsqueeze(dim + 1).expand(shape).flatten(dim, dim + 1)
return x
class PointCloudProcessor:
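    """Normalize a point cloud for the model.

    Coordinates are centered and scaled into the unit sphere (the center and
    scale are cached so prompts can be normalized consistently later), and RGB
    colors are mapped from [0, 255] to [-1, 1].
    """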
def __init__(self, device="cuda", batch=True, return_tensors="pt"):
self.device = device
self.batch = batch
self.return_tensors = return_tensors
self.center = None
self.scale = None
def __call__(self, xyz: np.ndarray, rgb: np.ndarray):
        # The original data is z-up. Make it y-up (disabled):
        # from scipy.spatial.transform import Rotation
        # rot = Rotation.from_euler("x", -90, degrees=True)
        # xyz = rot.apply(xyz)
if self.center is None or self.scale is None:
self.center = xyz.mean(0)
self.scale = np.max(np.linalg.norm(xyz - self.center, axis=-1))
xyz = (xyz - self.center) / self.scale
rgb = ((rgb / 255.0) - 0.5) * 2
if self.return_tensors == "np":
coords = np.float32(xyz)
feats = np.float32(rgb)
if self.batch:
coords = np.expand_dims(coords, 0)
feats = np.expand_dims(feats, 0)
elif self.return_tensors == "pt":
coords = torch.tensor(xyz, dtype=torch.float32, device=self.device)
feats = torch.tensor(rgb, dtype=torch.float32, device=self.device)
if self.batch:
coords = coords.unsqueeze(0)
feats = feats.unsqueeze(0)
else:
raise ValueError(self.return_tensors)
return coords, feats
def normalize(self, xyz):
return (xyz - self.center) / self.scale
class PointCloudSAMPredictor:
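    """Stateful wrapper around the Point-SAM model.

    Caches the encoded point cloud, the current point prompts, and the mask
    logits from the previous prediction so that subsequent prompts refine the
    same object. Prompt tensors are created on CUDA, so a GPU is assumed.
    """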
input_xyz: np.ndarray
input_rgb: np.ndarray
prompt_coords: list[tuple[float, float, float]]
prompt_labels: list[int]
coords: torch.Tensor
feats: torch.Tensor
pc_embedding: torch.Tensor
patches: dict[str, torch.Tensor]
prompt_mask: torch.Tensor
def __init__(self):
print("Created model")
model = build_point_sam("./model-2.safetensors")
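        # Override the patch grouper: 1024 patches of 128 points each.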
model.pc_encoder.patch_embed.grouper.num_groups = 1024
model.pc_encoder.patch_embed.grouper.group_size = 128
if torch.cuda.is_available():
model = model.cuda()
model.eval()
self.model = model
self.input_rgb = None
self.input_xyz = None
self.input_processor = None
self.coords = None
self.feats = None
self.pc_embedding = None
self.patches = None
self.prompt_coords = None
self.prompt_labels = None
self.prompt_mask = None
self.candidate_index = 0
@torch.no_grad()
def set_pointcloud(self, xyz, rgb):
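        """Encode a new point cloud, cache its embedding and patches, and reset any previous mask."""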
self.input_xyz = xyz
self.input_rgb = rgb
self.input_processor = PointCloudProcessor()
coords, feats = self.input_processor(xyz, rgb)
self.coords = coords
self.feats = feats
pc_embedding, patches = self.model.pc_encoder(self.coords, self.feats)
self.pc_embedding = pc_embedding
self.patches = patches
self.prompt_mask = None
def set_prompts(self, prompt_coords, prompt_labels):
self.prompt_coords = prompt_coords
self.prompt_labels = prompt_labels
@torch.no_grad()
def predict_mask(self):
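        """Predict candidate masks for the current prompts.

        Returns a boolean array of shape [num_outputs, num_points] with the
        candidates sorted by predicted IoU (best first); the logits are kept so
        the selected candidate can seed the next prediction.
        """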
normalized_prompt_coords = self.input_processor.normalize(
np.array(self.prompt_coords)
)
prompt_coords = torch.tensor(
normalized_prompt_coords, dtype=torch.float32, device="cuda"
)
prompt_labels = torch.tensor(
self.prompt_labels, dtype=torch.bool, device="cuda"
)
prompt_coords = prompt_coords.reshape(1, -1, 3)
prompt_labels = prompt_labels.reshape(1, -1)
multimask_output = prompt_coords.shape[1] == 1
# [B * M, num_outputs, num_points], [B * M, num_outputs]
        def decode_masks(
            coords, feats, pc_embedding, patches,
            prompt_coords, prompt_labels, prompt_masks, multimask_output,
        ):
            # `coords`/`feats` are superseded by the patch-level tensors pulled
            # from `patches` below; they are kept to mirror the call site.
            pc_embeddings = pc_embedding
centers = patches["centers"]
knn_idx = patches["knn_idx"]
coords = patches["coords"]
feats = patches["feats"]
aux_inputs = AuxInputs(coords=coords, features=feats, centers=centers)
pc_pe = self.model.point_encoder.pe_layer(centers)
sparse_embeddings = self.model.point_encoder(prompt_coords, prompt_labels)
dense_embeddings = self.model.mask_encoder(prompt_masks, coords, centers, knn_idx)
dense_embeddings = repeat_interleave(
dense_embeddings, sparse_embeddings.shape[0] // dense_embeddings.shape[0], 0
)
logits, iou_preds = self.model.mask_decoder(
pc_embeddings,
pc_pe,
sparse_embeddings,
dense_embeddings,
aux_inputs=aux_inputs,
multimask_output=multimask_output,
)
return logits, iou_preds
        prompt_masks = (
            self.prompt_mask[self.candidate_index].unsqueeze(0)
            if self.prompt_mask is not None
            else None
        )
        logits, scores = decode_masks(
            self.coords,
            self.feats,
            self.pc_embedding,
            self.patches,
            prompt_coords,
            prompt_labels,
            prompt_masks,
            multimask_output,
        )
logits = logits.squeeze(0)
scores = scores.squeeze(0)
# if multimask_output:
# index = scores.argmax(0).item()
# logit = logits[index]
# else:
# logit = logits.squeeze(0)
# self.prompt_mask = logit.unsqueeze(0)
# pred_mask = logit > 0
# return pred_mask.cpu().numpy()
# Sort according to scores
_, indices = scores.sort(descending=True)
logits = logits[indices]
self.prompt_mask = logits # [num_outputs, num_points]
self.candidate_index = 0
return (logits > 0).cpu().numpy()
def set_candidate(self, index):
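        """Select which of the sorted candidate masks to refine next."""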
self.candidate_index = index
predictor = PointCloudSAMPredictor()
@app.route("/")
def index():
return app.send_static_file("index.html")
@app.route("/assets/<path:path>")
def assets_route(path):
print(path)
return app.send_static_file(f"assets/{path}")
@app.route("/hello_world", methods=["GET"])
def hello_world():
return "Hello, World!"
@app.route("/set_pointcloud", methods=["POST"])
def set_pointcloud():
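    """Register a point cloud.

    Expects JSON with "points" and "colors" arrays (reshaped to [N, 3]);
    returns the `user_id` under which the embedding is cached.
    """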
request_data = request.get_json()
# print(request_data)
# print(type(request_data["points"]))
# print(type(request_data["colors"]))
    xyz = np.array(request_data["points"]).reshape(-1, 3)
    rgb = np.array(request_data["colors"]).reshape(-1, 3)
predictor.set_pointcloud(xyz, rgb)
pc_embedding = predictor.pc_embedding.cpu()
patches = {"centers": predictor.patches["centers"].cpu(), "knn_idx": predictor.patches["knn_idx"].cpu(), "coords": predictor.coords.cpu(), "feats": predictor.feats.cpu()}
center = predictor.input_processor.center
scale = predictor.input_processor.scale
global point_info_id
global point_info_list
    point_info_list[point_info_id] = {
        "pc_embedding": pc_embedding,
        "patches": patches,
        "center": center,
        "scale": scale,
        "prompt_mask": None,
    }
return_msg = {"user_id": point_info_id}
    point_info_id = (point_info_id + 1) % MAX_POINT_ID  # reuse slots instead of overrunning the list
return jsonify(return_msg)
@app.route("/set_candidate", methods=["POST"])
def set_candidate():
request_data = request.get_json()
candidate_index = request_data["index"]
predictor.set_candidate(candidate_index)
return "success"
def visualize_pcd_with_prompts(xyz, rgb, prompt_coords, prompt_labels):
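    """Debug helper: build a trimesh scene with the point cloud and prompt spheres."""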
import trimesh
pcd = trimesh.PointCloud(xyz, rgb)
prompt_spheres = []
for i, coord in enumerate(prompt_coords):
sphere = trimesh.creation.icosphere()
sphere.apply_scale(0.02)
sphere.apply_translation(coord)
sphere.visual.vertex_colors = [255, 0, 0] if prompt_labels[i] else [0, 255, 0]
prompt_spheres.append(sphere)
return trimesh.Scene([pcd] + prompt_spheres)
@app.route("/set_prompts", methods=["POST"])
def set_prompts():
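    """Predict masks for point prompts.

    Expects JSON with "prompt_coords" ([n_prompts, 3], in the original frame),
    "prompt_labels" (0 = negative, 1 = positive), and the "user_id" returned by
    /set_pointcloud; responds with {"mask": ...} containing the candidate masks.
    """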
global point_info_list
request_data = request.get_json()
print(request_data.keys())
# [n_prompts, 3]
prompt_coords = request_data["prompt_coords"]
# [n_prompts]. 0 for negative, 1 for positive
prompt_labels = request_data["prompt_labels"]
user_id = request_data["user_id"]
print(user_id)
point_info = point_info_list[user_id]
predictor.pc_embedding = point_info["pc_embedding"].cuda()
patches = point_info["patches"]
    predictor.patches = {
        "centers": patches["centers"].cuda(),
        "knn_idx": patches["knn_idx"].cuda(),
        "coords": patches["coords"].cuda(),
        "feats": patches["feats"].cuda(),
    }
predictor.input_processor.center = point_info["center"]
predictor.input_processor.scale = point_info["scale"]
if point_info["prompt_mask"] is not None:
predictor.prompt_mask = point_info["prompt_mask"].cuda()
else:
predictor.prompt_mask = None
# instance_id = request_data["instance_id"] # int
if len(prompt_coords) == 0:
predictor.prompt_mask = None
pred_mask = np.zeros([len(prompt_coords)], dtype=np.bool_)
return jsonify({"mask": pred_mask.tolist()})
predictor.set_prompts(prompt_coords, prompt_labels)
pred_mask = predictor.predict_mask()
point_info_list[user_id]["prompt_mask"] = predictor.prompt_mask.cpu()
# # Visualize
# xyz = predictor.coords.cpu().numpy()[0]
# rgb = predictor.feats.cpu().numpy()[0] * 0.5 + 0.5
# prompt_coords = predictor.input_processor.normalize(np.array(predictor.prompt_coords))
# scene = visualize_pcd_with_prompts(xyz, rgb, prompt_coords, predictor.prompt_labels)
# scene.show()
return jsonify({"mask": pred_mask.tolist()})
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7860)
args = parser.parse_args()
app.run(host=args.host, port=args.port, debug=True)