# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
import os
import ast
import time
import random
from PIL import Image
import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from plyfile import PlyData
import gradio as gr
import plotly.graph_objs as go
from sam_3d import SAM3DDemo


def pc_to_plot(pc):
    # Render a PLY vertex array as a Plotly 3D scatter: one marker per point,
    # colored by the per-vertex red/green/blue fields, with all axes hidden.
    return go.Figure(
        data=[
            go.Scatter3d(
                x=pc['x'], y=pc['y'], z=pc['z'],
                mode='markers',
                marker=dict(
                    size=2,
                    color=['rgb({},{},{})'.format(r, g, b) for r, g, b in zip(pc['red'], pc['green'], pc['blue'])],
                ),
            )
        ],
        layout=dict(
            scene=dict(xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False))
        ),
    )
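

# Usage sketch (illustrative helper, not wired into the demo): build the
# figure for one scene straight from its PLY file, using the same
# ./scannet_data layout as the loading loop further below.
def plot_scene(scene_name):
    vertices = PlyData.read(f"./scannet_data/{scene_name}/{scene_name}.ply").elements[0].data
    return pc_to_plot(vertices)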


def inference(scene_name, granularity, coords, plot):
    # `plot` carries the input-visualization figure from the UI and is unused here.
    print(scene_name, coords)
    sam_3d = SAM3DDemo('vit_b', 'sam_vit_b_01ec64.pth', scene_name)
    coords = ast.literal_eval(coords)  # parse the "[x, y, z]" string from the textbox
    data_point_select, rgb_img_w_points, rgb_img_w_masks, data_final = sam_3d.run_with_coord(coords, int(granularity))
    return pc_to_plot(data_point_select), Image.fromarray(rgb_img_w_points), Image.fromarray(rgb_img_w_masks), pc_to_plot(data_final)
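

# Defensive-parsing sketch (illustrative; parse_coords is not part of the
# original demo): ast.literal_eval raises on malformed textbox input, so a
# guard like this would surface a readable Gradio error instead of a traceback.
def parse_coords(text):
    try:
        coords = [float(c) for c in ast.literal_eval(text)]
    except (ValueError, SyntaxError, TypeError):
        raise gr.Error(f"Could not parse coordinates: {text!r}")
    if len(coords) != 3:
        raise gr.Error("Expected exactly three values: [x, y, z]")
    return coords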


# Preload the point clouds for the supported ScanNet scenes so the example
# gallery below can display them.
plydatas = []
for scene_name in ['scene0000_00', 'scene0001_00', 'scene0002_00']:
    plydata = PlyData.read(f"./scannet_data/{scene_name}/{scene_name}.ply")
    data = plydata.elements[0].data
    plydatas.append(data)
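
# Sanity-check sketch (illustrative): each loaded vertex array is a NumPy
# structured array whose dtype names the per-vertex fields; pc_to_plot needs
# at least x/y/z and red/green/blue to be present.
for data in plydatas:
    missing = {'x', 'y', 'z', 'red', 'green', 'blue'} - set(data.dtype.names)
    assert not missing, f"PLY is missing vertex fields: {missing}"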

examples = [
    ['scene0000_00', 0, [0, -2.5, 0.7], pc_to_plot(plydatas[0])],
    ['scene0001_00', 0, [0, -2.5, 1], pc_to_plot(plydatas[1])],
    ['scene0002_00', 0, [0, -2.5, 1], pc_to_plot(plydatas[2])],
]

title = 'Segment Anything on 3D indoor point clouds'
description = """
Gradio demo for Segment Anything on 3D indoor scenes (ScanNet supported). \n
The logic is straightforward: 1) pick a point in 3D; 2) project the 3D point into the images that see it; 3) run 2D SAM on those images; 4) reproject the 2D masks back to 3D; 5) visualize the result. \n
Unfortunately, clicking the point cloud to generate coordinates automatically is not supported. You may need to read the coordinates off the plot and enter them manually. \n
"""
article = """ | |
<p style='text-align: center'> | |
<a href='https://arxiv.org/abs/2210.04150' target='_blank'> | |
Open-Vocabulary Semantic Segmentation with Mask-adapted CLIP | |
</a> | |
| | |
<a href='https://github.com/facebookresearch/ov-seg' target='_blank'>Github Repo</a></p> | |
""" | |

gr.Interface(
    inference,
    inputs=[
        gr.Dropdown(choices=['scene0000_00', 'scene0001_00', 'scene0002_00'], label="ScanNet scene name (limited scenes supported)"),
        gr.Dropdown(choices=[0, 1, 2], label="Mask granularity from 0 (coarsest) to 2 (most precise)"),
        gr.Textbox(lines=1, label='Coordinates'),
        gr.Plot(label="Input point cloud (for visualization and point finding only; click response is not supported yet)"),
    ],
    outputs=[
        gr.Plot(label='Selected point(s): red points mark the 10 points closest to your input anchor point'),
        gr.Image(label='Selected image with projected points'),
        gr.Image(label='Selected image after SAM'),
        gr.Plot(label='Output point cloud: blue points represent the mask'),
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
).launch()