# sam_3d / app.py
# JeffLiang
# fix stuck when perform inference
# 1a6f9ac
# raw
# history blame
# 3.6 kB
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
import os
import ast
import time
import random
from PIL import Image
import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from plyfile import PlyData
import gradio as gr
import plotly.graph_objs as go
from sam_3d import SAM3DDemo
def pc_to_plot(pc):
    """Build a 3-D Plotly scatter figure from a coloured point table.

    Args:
        pc: Object indexable by the field names ``'x'``, ``'y'``, ``'z'``,
            ``'red'``, ``'green'`` and ``'blue'``, each yielding one value
            per point.

    Returns:
        A ``go.Figure`` holding a single ``Scatter3d`` marker trace with all
        three scene axes hidden.
    """
    xs, ys, zs = pc['x'], pc['y'], pc['z']

    # One CSS-style colour string per point, assembled from its RGB channels.
    point_colors = []
    for r, g, b in zip(pc['red'], pc['green'], pc['blue']):
        point_colors.append('rgb({},{},{})'.format(r, g, b))

    trace = go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        mode='markers',
        marker=dict(size=2, color=point_colors),
    )

    # Hide every axis so that only the points themselves are drawn.
    scene = {axis: dict(visible=False) for axis in ('xaxis', 'yaxis', 'zaxis')}

    return go.Figure(data=[trace], layout=dict(scene=scene))
def inference(scene_name, granularity, coords, plot):
    """Run SAM-based 3-D segmentation for one anchor point of a scene.

    Args:
        scene_name: Scene identifier forwarded to ``SAM3DDemo``.
        granularity: Mask granularity; converted with ``int()`` before use.
        coords: Anchor coordinates, either as a Python-literal string such as
            ``"[0, -2.5, 0.7]"`` (what the text box delivers) or as an
            already-parsed list/tuple.
        plot: Unused. It is accepted only because the UI passes the preview
            plot as the fourth input.

    Returns:
        A 4-tuple ``(selected_points_figure, image_with_points,
        image_with_masks, result_figure)``: two Plotly figures produced by
        ``pc_to_plot`` and two PIL images built with ``Image.fromarray``.

    Raises:
        ValueError, SyntaxError: Propagated from ``ast.literal_eval`` when
            ``coords`` is a string that is not a valid Python literal.
    """
    print(scene_name, coords)

    # Parse before constructing the model so malformed input fails fast.
    if isinstance(coords, str):
        # literal_eval rejects leading whitespace on Python < 3.10, so trim
        # whatever the user typed around the literal.
        coords = ast.literal_eval(coords.strip())
    # Non-string values (e.g. a list taken straight from `examples`) are
    # already parsed; literal_eval would raise ValueError on them.

    sam_3d = SAM3DDemo('vit_b', 'sam_vit_b_01ec64.pth', scene_name)
    data_point_select, rgb_img_w_points, rgb_img_w_masks, data_final = (
        sam_3d.run_with_coord(coords, int(granularity))
    )
    return (
        pc_to_plot(data_point_select),
        Image.fromarray(rgb_img_w_points),
        Image.fromarray(rgb_img_w_masks),
        pc_to_plot(data_final),
    )
# Read each bundled scene's PLY file once at start-up and keep the data of its
# first element; these tables feed the example preview plots.
plydatas = [
    PlyData.read(f"./scannet_data/{scene_name}/{scene_name}.ply").elements[0].data
    for scene_name in ('scene0000_00', 'scene0001_00', 'scene0002_00')
]
# One example row per scene, laid out in the order of the interface inputs:
# [scene name, mask granularity, anchor coordinates, preview figure].
examples = [
    [scene_id, 0, anchor, pc_to_plot(cloud)]
    for scene_id, anchor, cloud in zip(
        ('scene0000_00', 'scene0001_00', 'scene0002_00'),
        ([0, -2.5, 0.7], [1.9, 1.1, 0.5], [0.58, 0.47, 0.25]),
        plydatas,
    )
]
# Static page text handed to gr.Interface below via its `title`,
# `description` and `article` arguments.

# Page heading.
title = 'Segment_Anything on 3D in-door point clouds'
# Introductory text: summarises the five-step pipeline and tells the user that
# anchor coordinates have to be typed in by hand.
description = """
Gradio Demo for Segmenting Anything on 3D Indoor Scenes (ScanNet supported). \n
The logic is straightforward: 1) Find a point in 3D. 2) Project the 3D point onto valid images. 3) Perform 2D SAM on the valid images. 4) Reproject the 2D results back to 3D. 5) Visualize the results. \n
Unfortunately, this demo does not support automatically generating coordinates by clicking on the point cloud. You may need to manually write down the coordinates and input them. \n
Play with the examples below first and try to modify the coordinates and mask granularity. \n
"""
# HTML snippet with reference links.
# NOTE(review): both links point at the OV-Seg paper/repository rather than a
# SAM-3D resource - presumably carried over from another demo; confirm.
article = """
<p style='text-align: center'>
<a href='https://arxiv.org/abs/2210.04150' target='_blank'>
Open-Vocabulary Semantic Segmentation with Mask-adapted CLIP
</a>
|
<a href='https://github.com/facebookresearch/ov-seg' target='_blank'>Github Repo</a></p>
"""
# Input widgets, listed in the positional order of `inference`'s parameters:
# scene name, granularity, coordinates, preview plot.
demo_inputs = [
    gr.Dropdown(
        choices=['scene0000_00', 'scene0001_00', 'scene0002_00'],
        label="Scannet scene name (limited scenes supported)",
    ),
    gr.Dropdown(
        choices=[0, 1, 2],
        label="Mask granularity from 0 (most coarse) to 2 (most precise)",
    ),
    gr.Textbox(lines=1, label='Coordinates'),
    gr.Plot(label="Input Point cloud (For visualization and point finding only, click responce not supported yet.)"),
]

# Output widgets, matching the 4-tuple returned by `inference`.
demo_outputs = [
    gr.Plot(label='Selected point(s): red points show the top 10 cloest points for your input anchor point'),
    gr.Image(label='Selected image with projected points'),
    gr.Image(label='Selected image processed after SAM'),
    gr.Plot(label='Output Point cloud: blue points represent the mask'),
]

demo = gr.Interface(
    fn=inference,
    inputs=demo_inputs,
    outputs=demo_outputs,
    title=title,
    description=description,
    article=article,
    examples=examples,
)

# Enable the request queue, then start serving.
demo.queue().launch()