# sam_3d / app.py
# JeffLiang - init - fcdbf88
# raw - history - blame - 3.44 kB
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
import os
import ast
import time
import random
from PIL import Image
import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from plyfile import PlyData
import gradio as gr
import plotly.graph_objs as go
from sam_3d import SAM3DDemo
def pc_to_plot(pc):
    """Build a Plotly 3D scatter figure from a coloured point cloud.

    Args:
        pc: Mapping-like record indexed by the field names ``'x'``, ``'y'``,
            ``'z'``, ``'red'``, ``'green'`` and ``'blue'``, each yielding one
            value per point.

    Returns:
        A ``go.Figure`` holding a single marker-only ``Scatter3d`` trace with
        all three scene axes hidden.
    """
    xs, ys, zs = pc['x'], pc['y'], pc['z']

    # One CSS-style colour string per point, built from its RGB channels.
    as_rgb = 'rgb({},{},{})'.format
    point_colors = [
        as_rgb(r, g, b)
        for r, g, b in zip(pc['red'], pc['green'], pc['blue'])
    ]

    cloud = go.Scatter3d(
        x=xs, y=ys, z=zs,
        mode='markers',
        marker={'size': 2, 'color': point_colors},
    )

    # Hide the axes so that only the points themselves are drawn.
    scene = {
        'xaxis': {'visible': False},
        'yaxis': {'visible': False},
        'zaxis': {'visible': False},
    }
    return go.Figure(data=[cloud], layout={'scene': scene})
def inference(scene_name, granularity, coords, plot):
    """Run the SAM-3D demo for one anchor point and build the four outputs.

    Args:
        scene_name: Scene identifier forwarded to ``SAM3DDemo``.
        granularity: Mask granularity level; converted with ``int`` before
            being passed to ``run_with_coord``.
        coords: Anchor coordinates. Either text holding a Python literal
            (for example ``"[0, -2.5, 0.7]"``) or an already-parsed value
            such as a list, which is passed through unchanged.
        plot: Unused. Present only because the Gradio interface supplies a
            fourth input component.

    Returns:
        A 4-tuple in the order expected by the interface outputs: a plot of
        the selected points, the image with projected points, the image with
        masks, and a plot of the final point cloud.

    Raises:
        ValueError: If ``coords`` is text that is not a valid Python literal
            (including an empty string).
    """
    print(scene_name, coords)

    # Validate the cheap user input first so a typo fails fast, before the
    # model checkpoint is loaded.
    if isinstance(coords, str):
        try:
            # literal_eval only accepts literals, so arbitrary code typed
            # into the textbox is never executed.
            coords = ast.literal_eval(coords.strip())
        except (ValueError, SyntaxError) as err:
            raise ValueError(
                "Could not parse coordinates {!r}; expected a literal such as "
                "[0, -2.5, 0.7]".format(coords)
            ) from err

    sam_3d = SAM3DDemo('vit_b', 'sam_vit_b_01ec64.pth', scene_name)
    data_point_select, rgb_img_w_points, rgb_img_w_masks, data_final = sam_3d.run_with_coord(
        coords, int(granularity)
    )
    return (
        pc_to_plot(data_point_select),
        Image.fromarray(rgb_img_w_points),
        Image.fromarray(rgb_img_w_masks),
        pc_to_plot(data_final),
    )
# Read the first PLY element of every bundled scene once at start-up; these
# records are only used to draw the example input point clouds below.
plydatas = []
for scene_name in ('scene0000_00', 'scene0001_00', 'scene0002_00'):
    ply_path = f"./scannet_data/{scene_name}/{scene_name}.ply"
    plydata = PlyData.read(ply_path)
    data = plydata.elements[0].data
    plydatas.append(data)

# One example row per scene, matching the interface inputs in order:
# [scene name, granularity, anchor coordinates, input point-cloud plot].
examples = [
    [example_scene, 0, anchor, pc_to_plot(points)]
    for example_scene, anchor, points in zip(
        ('scene0000_00', 'scene0001_00', 'scene0002_00'),
        ([0, -2.5, 0.7], [0, -2.5, 1], [0, -2.5, 1]),
        plydatas,
    )
]
# Text passed to gr.Interface as the `title` argument.
title = 'Segment_Anything on 3D in-door point clouds'
# Text passed to gr.Interface as the `description` argument; it outlines the
# five pipeline steps and tells users to type coordinates by hand.
# NOTE(review): this user-facing text contains typos ("straighforward",
# "Unfortunatly"); correcting them changes the displayed string.
description = """
Gradio Demo for Segment Anything on 3D indoor scenes (ScanNet supported). \n
The logic is straighforward: 1) Find a point in 3D; 2) project the 3D point to valid images; 3) perform 2D SAM on valid images; 4) reproject 2D results back to 3D; 5) Visualization.
Unfortunatly, it does not support click the point cloud to generate coordinates automatically. You may want to write down the coordinates and put it manually. \n
"""
# HTML snippet passed to gr.Interface as the `article` argument; it links to
# an arXiv paper and to the ov-seg GitHub repository.
article = """
<p style='text-align: center'>
<a href='https://arxiv.org/abs/2210.04150' target='_blank'>
Open-Vocabulary Semantic Segmentation with Mask-adapted CLIP
</a>
|
<a href='https://github.com/facebookresearch/ov-seg' target='_blank'>Github Repo</a></p>
"""
# Input widgets, listed in the positional order of `inference`'s parameters.
input_components = [
    gr.Dropdown(
        choices=['scene0000_00', 'scene0001_00', 'scene0002_00'],
        label="Scannet scene name (limited scenes supported)",
    ),
    gr.Dropdown(
        choices=[0, 1, 2],
        label="Mask granularity from 0 (most coarse) to 2 (most precise)",
    ),
    gr.Textbox(lines=1, label='Coordinates'),
    gr.Plot(label="Input Point cloud (For visualization and point finding only, click responce not supported yet.)"),
]

# Output widgets, listed in the order of the tuple returned by `inference`.
output_components = [
    gr.Plot(label='Selected point(s): red points show the top 10 cloest points for your input anchor point'),
    gr.Image(label='Selected image with projected points'),
    gr.Image(label='Selected image processed after SAM'),
    gr.Plot(label='Output Point cloud: blue points represent the mask'),
]

# Assemble the interface and start serving it as soon as the script runs.
demo = gr.Interface(
    inference,
    inputs=input_components,
    outputs=output_components,
    title=title,
    description=description,
    article=article,
    examples=examples,
)
demo.launch()