Spaces:
Running
Running
import os | |
import sys | |
import json | |
import uuid | |
import numpy as np | |
import gradio as gr | |
import trimesh | |
import zipfile | |
import subprocess | |
from datetime import datetime | |
from functools import partial | |
from PIL import Image, ImageChops | |
from huggingface_hub import snapshot_download | |
from gradio_model3dcolor import Model3DColor | |
from gradio_model3dnormal import Model3DNormal | |
# Use a sibling checkout of the API code when it exists; otherwise download it
# from the Hub (this path requires the HF_TOKEN environment variable).
is_local_run = os.path.exists("../SpaRP_API")
if is_local_run:
    code_dir = "../SpaRP_API"
else:
    code_dir = snapshot_download("sudo-ai/SpaRP_API", token=os.environ['HF_TOKEN'])
    # The downloaded bundle ships the example images as a zip archive;
    # unpack it into the current working directory.
    zip_file_path = f'{code_dir}/examples.zip'
    with zipfile.ZipFile(zip_file_path, 'r') as archive:
        archive.extractall(os.getcwd())

# api.json holds the command templates for each pipeline stage; their
# "{tmp_dir}"-style placeholders are filled with str.replace and the result is
# run through os.system by the helpers below.
with open(f'{code_dir}/api.json', 'r') as api_file:
    api_dict = json.load(api_file)
SEGM_i_CALL, SEGM_CALL, UNPOSED_CALL, MESH_CALL = (
    api_dict[key] for key in ("SEGM_i_CALL", "SEGM_CALL", "UNPOSED_CALL", "MESH_CALL")
)
_TITLE = ( | |
"""SpaRP: Fast 3D Object Reconstruction and Pose Estimation from Sparse Views""" | |
) | |
_DESCRIPTION = ( | |
"""Try SpaRP to reconstruct 3D textured mesh from one or a few unposed images!""" | |
) | |
_PR = """ | |
<div> | |
<b><em>Check out <a href="https://www.sudo.ai/3dgen">Hillbot (sudoAI)</a> for more details and advanced features.</em></b> | |
</div> | |
""" | |
# HTML prepended to every status banner (see message()): loads the Bootstrap
# 5.3.2 stylesheet, pinned by its Subresource-Integrity hash, and forces black
# text inside alert boxes. The package specifier must be "bootstrap@5.3.2" --
# the integrity hash below is the one published for that exact release, so any
# other (or garbled) version string makes the stylesheet fail to load.
STYLE = """
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="stylesheet" integrity="sha384-T3c6CoIi6uLrA9TneNEoa7RxnatzjcDSCmG1MXxSR1GAsXEV/Dwwykc2MPK8M2HN" crossorigin="anonymous">
<style>
.alert, .alert div, .alert b {
color: black !important;
}
</style>
"""
# Inline Bootstrap-Icons SVGs keyed by banner type:
#   info   -> info-circle-fill
#   cursor -> hand-index-thumb-fill
#   wait   -> hourglass-split
#   done   -> check-circle-fill
ICONS = dict(
    info="""<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="#0d6efd" class="bi bi-info-circle-fill flex-shrink-0 me-2" viewBox="0 0 16 16">
<path d="M8 16A8 8 0 1 0 8 0a8 8 0 0 0 0 16zm.93-9.412-1 4.705c-.07.34.029.533.304.533.194 0 .487-.07.686-.246l-.088.416c-.287.346-.92.598-1.465.598-.703 0-1.002-.422-.808-1.319l.738-3.468c.064-.293.006-.399-.287-.47l-.451-.081.082-.381 2.29-.287zM8 5.5a1 1 0 1 1 0-2 1 1 0 0 1 0 2z"/>
</svg>""",
    cursor="""<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="#0dcaf0" class="bi bi-hand-index-thumb-fill flex-shrink-0 me-2" viewBox="0 0 16 16">
<path d="M8.5 1.75v2.716l.047-.002c.312-.012.742-.016 1.051.046.28.056.543.18.738.288.273.152.456.385.56.642l.132-.012c.312-.024.794-.038 1.158.108.37.148.689.487.88.716.075.09.141.175.195.248h.582a2 2 0 0 1 1.99 2.199l-.272 2.715a3.5 3.5 0 0 1-.444 1.389l-1.395 2.441A1.5 1.5 0 0 1 12.42 16H6.118a1.5 1.5 0 0 1-1.342-.83l-1.215-2.43L1.07 8.589a1.517 1.517 0 0 1 2.373-1.852L5 8.293V1.75a1.75 1.75 0 0 1 3.5 0z"/>
</svg>""",
    wait="""<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="#6c757d" class="bi bi-hourglass-split flex-shrink-0 me-2" viewBox="0 0 16 16">
<path d="M2.5 15a.5.5 0 1 1 0-1h1v-1a4.5 4.5 0 0 1 2.557-4.06c.29-.139.443-.377.443-.59v-.7c0-.213-.154-.451-.443-.59A4.5 4.5 0 0 1 3.5 3V2h-1a.5.5 0 0 1 0-1h11a.5.5 0 0 1 0 1h-1v1a4.5 4.5 0 0 1-2.557 4.06c-.29.139-.443.377-.443.59v.7c0 .213.154.451.443.59A4.5 4.5 0 0 1 12.5 13v1h1a.5.5 0 0 1 0 1h-11zm2-13v1c0 .537.12 1.045.337 1.5h6.326c.216-.455.337-.963.337-1.5V2h-7zm3 6.35c0 .701-.478 1.236-1.011 1.492A3.5 3.5 0 0 0 4.5 13s.866-1.299 3-1.48V8.35zm1 0v3.17c2.134.181 3 1.48 3 1.48a3.5 3.5 0 0 0-1.989-3.158C8.978 9.586 8.5 9.052 8.5 8.351z"/>
</svg>""",
    done="""<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="#198754" class="bi bi-check-circle-fill flex-shrink-0 me-2" viewBox="0 0 16 16">
<path d="M16 8A8 8 0 1 1 0 8a8 8 0 0 1 16 0zm-3.97-3.03a.75.75 0 0 0-1.08.022L7.477 9.417 5.384 7.323a.75.75 0 0 0-1.06 1.06L6.97 11.03a.75.75 0 0 0 1.079-.02l3.992-4.99a.75.75 0 0 0-.01-1.05z"/>
</svg>""",
)

# Banner type -> Bootstrap contextual alert class (used as "alert-<class>").
icons2alert = dict(
    info="primary",  # blue
    cursor="info",  # light blue
    wait="secondary",  # gray
    done="success",  # green
)
def message(text, icon_type="info"):
    """Build the HTML for a status banner.

    Args:
        text: HTML/text shown inside the banner.
        icon_type: key into ICONS / icons2alert ("info", "cursor", "wait",
            "done"); selects both the icon and the Bootstrap alert colour.

    Returns:
        STYLE followed by a Bootstrap ``alert`` div holding the icon and *text*.

    Raises:
        KeyError: if *icon_type* is not a known banner type.
    """
    alert_class = icons2alert[icon_type]
    icon_svg = ICONS[icon_type]
    return (
        f'{STYLE} <div class="alert alert-{alert_class} d-flex align-items-center" role="alert"> {icon_svg}\n'
        "<div>\n"
        f"{text}\n"
        "</div>\n"
        "</div>"
    )
def create_tmp_dir():
    """Create a fresh working directory for one processing session.

    The directory lives under ``../demo_exp/`` (relative to the current working
    directory) and is named ``<YYYY-MM-DD_HH-MM-SS>_<4 hex chars>``.

    Only four hex digits of a UUID are appended to a one-second timestamp, so
    two sessions started in the same second can draw the same name. The
    directory is therefore created exclusively (no ``exist_ok``), and a new
    suffix is drawn on collision. Each session then gets its own directory
    instead of silently sharing, and overwriting, another session's files.

    Returns:
        The relative path of the newly created directory.
    """
    while True:
        tmp_dir = (
            "../demo_exp/"
            + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
            + "_"
            + str(uuid.uuid4())[:4]
        )
        try:
            # Creates ../demo_exp as needed; fails only if tmp_dir itself exists.
            os.makedirs(tmp_dir)
        except FileExistsError:
            continue  # name already taken by another session; draw a new one
        print("create tmp_exp_dir", tmp_dir)
        return tmp_dir
def preprocess_imgs(tmp_dir, input_img, idx=None):
    """Remove the background of input image(s) using the segmentation commands.

    Two modes:

    * Batch (``input_img`` is a list and ``idx`` is None): every entry's first
      element is opened as an image path and saved as ``input_<n>.png``. The
      SEGM_i_CALL command is run for each index, and the ``seg_<n>.png``
      results are returned as a list of images. Presumably the entries are
      Gallery ``(path, caption)`` pairs -- confirm against the Gallery
      version in use.
    * Single: if ``idx`` is given, entry ``idx`` of the list is opened;
      otherwise ``input_img`` itself must be a PIL image. It is saved as
      ``input.png``, SEGM_CALL is run, and ``seg.png`` is returned resized to
      320x320 (LANCZOS).

    Command exit codes are not checked; a failed segmentation surfaces as a
    missing-file error when the output is opened.
    """
    if idx is None and isinstance(input_img, list):
        for n, entry in enumerate(input_img):
            Image.open(entry[0]).save(f"{tmp_dir}/input_{n}.png")
            command = SEGM_i_CALL.replace("{tmp_dir}", tmp_dir).replace("{i}", str(n))
            os.system(command)
        return [Image.open(f"{tmp_dir}/seg_{n}.png") for n in range(len(input_img))]

    if idx is not None:
        position = int(idx)
        print("image idx:", position)
        input_img = Image.open(input_img[position][0])
    input_img.save(f"{tmp_dir}/input.png")
    os.system(SEGM_CALL.replace("{tmp_dir}", tmp_dir))
    segmented = Image.open(f"{tmp_dir}/seg.png")
    return segmented.resize((320, 320), Image.Resampling.LANCZOS)
def ply_to_glb(ply_path):
    """Convert a ``.ply`` mesh to ``.glb`` with the bundled ``ply2glb.py`` script.

    The script is run with the interpreter executing this app
    (``sys.executable``) rather than whatever ``python`` happens to be first
    on PATH. That way it sees the same installed packages, and it also works
    on systems that only provide ``python3``.

    Args:
        ply_path: path of the input ``.ply`` file; passed to the script after
            a ``--`` separator.

    Returns:
        ``ply_path`` with ``.ply`` replaced by ``.glb`` -- the file the script
        is expected to write.

    Raises:
        RuntimeError: if that ``.glb`` file does not exist after the script
            has run. The message includes the exit code and stderr.
            Previously a non-existent path was returned, and the failure only
            surfaced later as an opaque error in the 3D viewer.
    """
    script_path = f"{code_dir}/ply2glb.py"
    result = subprocess.run(
        [sys.executable, script_path, "--", ply_path],
        capture_output=True,
        text=True,
    )
    print("Output of blender script:")
    print(result.stdout)
    glb_path = ply_path.replace(".ply", ".glb")
    if not os.path.exists(glb_path):
        print(result.stderr)
        raise RuntimeError(
            f"ply2glb.py did not produce {glb_path} "
            f"(exit code {result.returncode}):\n{result.stderr}"
        )
    return glb_path
def mesh_gen(tmp_dir, use_seg):
    """Run the reconstruction commands and export colour and normal-shaded meshes.

    Args:
        tmp_dir: session working directory holding the preprocessed inputs.
        use_seg: substituted into UNPOSED_CALL as ``str(use_seg)`` (so a bool
            becomes "True"/"False").

    Returns:
        ``(color_path, normal_path)``: ``.glb`` conversions of ``mesh.ply`` and
        of a copy whose vertex colours encode the vertex normals.
    """
    unposed_cmd = UNPOSED_CALL.replace("{tmp_dir}", tmp_dir).replace("{use_seg}", str(use_seg))
    os.system(unposed_cmd)
    os.system(MESH_CALL.replace("{tmp_dir}", tmp_dir))

    mesh = trimesh.load_mesh(f"{tmp_dir}/mesh.ply")
    # Map the negated normals from [-1, 1] into [0, 255] 8-bit colours
    # (assumes unit-length normals, as trimesh's vertex_normals are expected to be).
    normal_colors = ((1 - mesh.vertex_normals) / 2.0 * 255).astype(np.uint8)
    mesh.visual.vertex_colors = normal_colors
    mesh.export(f"{tmp_dir}/mesh_normal.ply", file_type="ply")

    color_glb = ply_to_glb(f"{tmp_dir}/mesh.ply")
    normal_glb = ply_to_glb(f"{tmp_dir}/mesh_normal.ply")
    return color_glb, normal_glb
def _same_pixels(a, b):
    """Return True when images *a* and *b* have the same size and identical pixels.

    *b* is converted to *a*'s mode first, so an RGB example file still matches
    the RGBA image it was loaded as. All bands are compared; in contrast,
    ``ImageChops.difference(...).getbbox()`` only looks at the alpha band of an
    RGBA image and therefore treats any two images with the same alpha mask
    as equal.
    """
    if a.size != b.size:
        return False
    if b.mode != a.mode:
        b = b.convert(a.mode)
    return a.tobytes() == b.tobytes()


def feed_example_to_gallery(img):
    """Expand a clicked example thumbnail into that example's input images.

    The thumbnail is matched pixel-for-pixel against the preloaded example
    thumbnails in ``display_imgs``. Each thumbnail's ``filename`` attribute
    holds its index, which names the sub-folder of ``data_folder`` containing
    the example's input images.

    Args:
        img: the example image delivered by gr.Examples.

    Returns:
        The images in the matching sub-folder, sorted by file name, or
        ``[img]`` when no thumbnail matches.
    """
    for entry in display_imgs:
        thumbnail = entry[0]
        if not _same_pixels(img, thumbnail):
            continue
        data_dir = os.path.join(data_folder, str(thumbnail.filename))
        return [
            Image.open(os.path.join(data_dir, data_fn))
            for data_fn in sorted(os.listdir(data_dir))
        ]
    return [img]
# Soft theme with a blue primary hue. Secondary buttons use the theme's
# neutral_100 fill, switching to neutral_200 on hover.
custom_theme = gr.themes.Soft(primary_hue="blue")
custom_theme = custom_theme.set(
    button_secondary_background_fill="*neutral_100",
    button_secondary_background_fill_hover="*neutral_200",
)
# Gradio blocks
# Builds the page layout and wires every event chain. Each chain is a sequence
# of .success() steps: a step runs only if the previous one returned without
# raising, which the is_cleared / not_cleared guards below rely on to stop a
# chain early.
with gr.Blocks(title=_TITLE, css="style.css", theme=custom_theme) as demo:
    # Per-session working directory. Each preprocessing chain replaces it via
    # create_tmp_dir() before using it. Generate (whose mesh_gen step reads it)
    # starts disabled and is only enabled at the end of those chains, so this
    # placeholder is not expected to be used for real work.
    # NOTE(review): the placeholder is under "./demo_exp" while create_tmp_dir()
    # writes under "../demo_exp" -- inconsistent, though apparently harmless.
    tmp_dir_unposed = gr.State("./demo_exp/placeholder")
    # Example thumbnails for gr.Examples. Each image's `filename` attribute is
    # overwritten with its index in sorted order. feed_example_to_gallery()
    # uses that index as the sub-folder name under examples_data/ from which
    # it loads the example's input images.
    display_folder = os.path.join(os.path.dirname(__file__), "examples_display")
    display_fns = os.listdir(display_folder)
    display_fns.sort()
    display_imgs = []
    for i, display_fn in enumerate(display_fns):
        file_path = os.path.join(display_folder, display_fn)
        img = Image.open(file_path)
        img.filename = i
        display_imgs.append([img])  # one value per example row: gr.Examples has one input
    data_folder = os.path.join(os.path.dirname(__file__), "examples_data")
    # UI
    with gr.Row():
        gr.Markdown("# " + _TITLE)
    with gr.Row():
        gr.Markdown("### " + _DESCRIPTION)
    with gr.Row():
        gr.Markdown(_PR)
    with gr.Row():
        # Status banner; the event chains below rewrite it through update_guide.
        guide_text = gr.HTML(
            message("Input image(s) of object that you want to generate mesh with.")
        )
    with gr.Row(variant="panel"):
        with gr.Column():
            with gr.Row():
                with gr.Column(scale=5):
                    # User-supplied view(s) of the object.
                    input_gallery = gr.Gallery(
                        label="Input Images",
                        show_label=False,
                        columns=[3],
                        rows=[2],
                        object_fit="contain",
                        height=400,
                        show_share_button=False,
                    )
                    # Hidden input of gr.Examples: clicking an example loads it
                    # here, which triggers the input_image.change chain below.
                    input_image = gr.Image(
                        type="pil",
                        image_mode="RGBA",
                        visible=False,
                    )
                with gr.Column(scale=5):
                    # Filled by preprocess_imgs() with the segmented images.
                    processed_gallery = gr.Gallery(
                        label="Background Removal",
                        columns=[3],
                        rows=[2],
                        object_fit="contain",
                        height=400,
                        interactive=False,
                        show_share_button=False,
                    )
            with gr.Row():
                with gr.Column(scale=5):
                    # run_on_click: on click, feed_example_to_gallery() expands
                    # the thumbnail into the input gallery.
                    example = gr.Examples(
                        examples=display_imgs,
                        inputs=[input_image],
                        outputs=[input_gallery],
                        fn=feed_example_to_gallery,
                        label="Image Examples (Click one of the images below to start)",
                        examples_per_page=10,
                        run_on_click=True,
                    )
                with gr.Column(scale=5):
                    with gr.Row():
                        # Passed to mesh_gen() as use_seg.
                        bg_removed_checkbox = gr.Checkbox(
                            value=True,
                            label="Use background removed images (uncheck to use original)",
                            interactive=True,
                        )
                    with gr.Row():
                        # Disabled until a preprocessing chain has finished.
                        run_btn = gr.Button(
                            "Generate",
                            variant="primary",
                            interactive=False,
                        )
            with gr.Row():
                with gr.Column(scale=5):
                    mesh_output = Model3DColor(
                        label="Generated Mesh (color)",
                        elem_id="mesh-out",
                        height=400,
                    )
                with gr.Column(scale=5):
                    mesh_output_normal = Model3DNormal(
                        label="Generated Mesh (normal)",
                        elem_id="mesh-normal-out",
                        height=400,
                    )
    # Callbacks
    # Button updates that toggle interactivity.
    disable_button = lambda: gr.Button(interactive=False)
    enable_button = lambda: gr.Button(interactive=True)
    # Returns an HTML update replacing the guide banner with a new message.
    update_guide = lambda GUIDE_TEXT, icon_type="info": gr.HTML(
        value=message(GUIDE_TEXT, icon_type)
    )
    # Chain guards: raising stops the .success() chain. is_cleared lets a chain
    # continue only for empty content; not_cleared only for non-empty content.
    def is_cleared(content):
        if content:
            raise ValueError  # gr.Error(visible=False) doesn't work, trick for not showing error message
    def not_cleared(content):
        if not content:
            raise ValueError  # gr.Error(visible=False) doesn't work, trick for not showing error message
    # Upload event listener for input gallery
    # Disable Generate -> new working dir -> "wait" banner -> segment every
    # gallery image (queued) -> "click Generate" banner -> enable Generate.
    input_gallery.upload(
        fn=disable_button,
        outputs=[run_btn],
        queue=False,
    ).success(
        fn=create_tmp_dir,
        outputs=[tmp_dir_unposed],
        queue=False,
    ).success(
        fn=partial(
            update_guide, "Removing background of the input image(s)...", "wait"
        ),
        outputs=[guide_text],
        queue=False,
    ).success(
        fn=preprocess_imgs,
        inputs=[tmp_dir_unposed, input_gallery],
        outputs=[processed_gallery],
        queue=True,
    ).success(
        fn=partial(update_guide, "Click <b>Generate</b> to generate mesh.", "cursor"),
        outputs=[guide_text],
        queue=False,
    ).success(
        fn=enable_button,
        outputs=[run_btn],
        queue=False,
    )
    # Clear event listener for input gallery
    # Fires on every gallery change but proceeds only when the gallery is now
    # empty: disable Generate, reset the hidden image, the segmentation results
    # and both mesh viewers, and restore the initial banner.
    input_gallery.change(
        fn=is_cleared,
        inputs=[input_gallery],
        queue=False,
    ).success(
        fn=disable_button,
        outputs=[run_btn],
        queue=False,
    ).success(
        fn=lambda: None,
        outputs=[input_image],
        queue=False,
    ).success(
        fn=lambda: None,
        outputs=[processed_gallery],
        queue=False,
    ).success(
        fn=lambda: None,
        outputs=[mesh_output],
        queue=False,
    ).success(
        fn=lambda: None,
        outputs=[mesh_output_normal],
        queue=False,
    ).success(
        fn=partial(
            update_guide,
            "Input image(s) of object that you want to generate mesh with.",
            "info",
        ),
        outputs=[guide_text],
        queue=False,
    )
    # Change event listener for input image
    # Runs when an example is loaded into the hidden input_image (when it is
    # reset to None, not_cleared stops the chain). It clears the meshes and
    # re-runs segmentation on the current gallery contents.
    # NOTE(review): preprocessing reads input_gallery, which the same example
    # click updates via feed_example_to_gallery -- confirm the gallery is
    # already updated when this chain reaches preprocess_imgs.
    input_image.change(
        fn=not_cleared,
        inputs=[input_image],
        queue=False,
    ).success(
        fn=disable_button,
        outputs=run_btn,
        queue=False,
    ).success(
        fn=lambda: None,
        outputs=[mesh_output],
        queue=False,
    ).success(
        fn=lambda: None,
        outputs=[mesh_output_normal],
        queue=False,
    ).success(
        fn=create_tmp_dir,
        outputs=tmp_dir_unposed,
        queue=False,
    ).success(
        fn=partial(
            update_guide, "Removing background of the input image(s)...", "wait"
        ),
        outputs=[guide_text],
        queue=False,
    ).success(
        fn=preprocess_imgs,
        inputs=[tmp_dir_unposed, input_gallery],
        outputs=[processed_gallery],
        queue=True,
    ).success(
        fn=partial(update_guide, "Click <b>Generate</b> to generate mesh.", "cursor"),
        outputs=[guide_text],
        queue=False,
    ).success(
        fn=enable_button,
        outputs=run_btn,
        queue=False,
    )
    # Click event listener for run button
    # Disable Generate -> clear both viewers -> "wait" banner -> mesh_gen
    # (queued) fills both viewers -> "done" banner -> re-enable Generate.
    run_btn.click(
        fn=disable_button,
        outputs=[run_btn],
        queue=False,
    ).success(
        fn=lambda: None,
        outputs=[mesh_output],
        queue=False,
    ).success(
        fn=lambda: None,
        outputs=[mesh_output_normal],
        queue=False,
    ).success(
        fn=partial(update_guide, "Generating the mesh...", "wait"),
        outputs=[guide_text],
        queue=False,
    ).success(
        fn=mesh_gen,
        inputs=[tmp_dir_unposed, bg_removed_checkbox],
        outputs=[mesh_output, mesh_output_normal],
        queue=True,
    ).success(
        fn=partial(
            update_guide,
            "Successfully generated the mesh. (It might take a few seconds to load the mesh)",
            "done",
        ),
        outputs=[guide_text],
        queue=False,
    ).success(
        fn=enable_button,
        outputs=[run_btn],
        queue=False,
    )
# Turn on the request queue (the long-running steps above are registered with
# queue=True), then serve on all network interfaces without a public share
# link and with the API page hidden.
demo.queue()
demo.launch(
    server_name="0.0.0.0",
    share=False,
    show_api=False,
    inline=False,
    debug=False,
)