# ImgRoboAssetGen / app.py
# Author: xinjie.wang — commit 0fd2448 ("update")
import os

# Must be set BEFORE `common` is imported: the ordering here suggests the
# shared module reads GRADIO_APP at import time to select the image-to-3d
# pipeline configuration — TODO confirm in common.py.
os.environ["GRADIO_APP"] = "imageto3d"
import gradio as gr

# Shared constants, CSS snippets and pipeline callbacks used by this UI.
from common import (
    MAX_SEED,
    VERSION,
    active_btn_by_content,
    end_session,
    extract_3d_representations_v2,
    extract_urdf,
    get_seed,
    image_css,
    image_to_3d,
    lighting_css,
    preprocess_image_fn,
    preprocess_sam_image_fn,
    select_point,
    start_session,
)
from gradio.themes import Default
from gradio.themes.utils.colors import slate
# Build the Gradio UI. Layout: the left column holds the image-input tabs
# (auto segmentation vs. interactive SAM segmentation), the generation
# settings and the optional asset attributes; the right column previews the
# generated video plus Gaussian/mesh representations. The event wiring at
# the bottom chains the three pipeline stages:
#   image_to_3d -> extract_3d_representations_v2 -> extract_urdf.
with gr.Blocks(
    delete_cache=(43200, 43200), theme=Default(primary_hue=slate)
) as demo:
    gr.Markdown(
        f"""
        ## Image to 3D Asset Pipeline \n
        version: {VERSION} \n
        <!-- The service is temporarily deployed on `dev015-10.34.8.82: CUDA 4`. -->
        """
    )
    gr.HTML(image_css)
    gr.HTML(lighting_css)
    with gr.Row():
        with gr.Column(scale=2):
            with gr.Tabs() as input_tabs:
                # Tab 0: single upload, background removed automatically.
                with gr.Tab(
                    label="Image(auto seg)", id=0
                ) as single_image_input_tab:
                    # Hidden cache of the raw (pre-segmentation) upload,
                    # passed back into image_to_3d later.
                    raw_image_cache = gr.Image(
                        format="png",
                        image_mode="RGB",
                        type="pil",
                        visible=False,
                    )
                    image_prompt = gr.Image(
                        label="Input Image",
                        format="png",
                        image_mode="RGBA",
                        type="pil",
                        height=400,
                        elem_classes=["image_fit"],
                    )
                    gr.Markdown(
                        """
                        If you are not satisfied with the auto segmentation
                        result, please switch to the `Image(SAM seg)` tab."""
                    )
                # Tab 1: interactive SAM segmentation from clicked points.
                with gr.Tab(
                    label="Image(SAM seg)", id=1
                ) as samimage_input_tab:
                    with gr.Row():
                        with gr.Column(scale=1):
                            image_prompt_sam = gr.Image(
                                label="Input Image",
                                type="numpy",
                                height=400,
                                elem_classes=["image_fit"],
                            )
                            # Holds the SAM-segmented RGBA result; hidden,
                            # but its change event gates the Generate button.
                            image_seg_sam = gr.Image(
                                label="SAM Seg Image",
                                image_mode="RGBA",
                                type="pil",
                                height=400,
                                visible=False,
                            )
                        with gr.Column(scale=1):
                            image_mask_sam = gr.AnnotatedImage(
                                elem_classes=["image_fit"]
                            )
                            fg_bg_radio = gr.Radio(
                                ["foreground_point", "background_point"],
                                label="Select foreground(green) or background(red) points, by default foreground",  # noqa
                                value="foreground_point",
                            )
                            gr.Markdown(
                                """ Click the `Input Image` to select SAM points,
                                after get the satisified segmentation, click `Generate`
                                button to generate the 3D asset. \n
                                Note: If the segmented foreground is too small relative
                                to the entire image area, the generation will fail.
                                """
                            )
            with gr.Accordion(label="Generation Settings", open=False):
                with gr.Row():
                    seed = gr.Slider(
                        0, MAX_SEED, label="Seed", value=0, step=1
                    )
                with gr.Row():
                    randomize_seed = gr.Checkbox(
                        label="Randomize Seed", value=False
                    )
                    # Toggles texture delighting during back-projection in
                    # the 3D-extraction stage.
                    project_delight = gr.Checkbox(
                        label="Backproject delighting",
                        value=True,
                    )
                gr.Markdown("Geo Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(
                        0.0,
                        10.0,
                        label="Guidance Strength",
                        value=7.5,
                        step=0.1,
                    )
                    ss_sampling_steps = gr.Slider(
                        1, 50, label="Sampling Steps", value=12, step=1
                    )
                gr.Markdown("Visual Appearance Generation")
                with gr.Row():
                    slat_guidance_strength = gr.Slider(
                        0.0,
                        10.0,
                        label="Guidance Strength",
                        value=3.0,
                        step=0.1,
                    )
                    slat_sampling_steps = gr.Slider(
                        1, 50, label="Sampling Steps", value=12, step=1
                    )
            # Stage-1 trigger; enabled only once a valid image is present.
            generate_btn = gr.Button(
                "Generate(~0.5 mins)", variant="primary", interactive=False
            )
            # Hidden path of the raw mesh .obj produced by stage 2.
            model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False)
            with gr.Row():
                extract_rep3d_btn = gr.Button(
                    "Extract 3D Representation(~2 mins)",
                    variant="primary",
                    interactive=False,
                )
            with gr.Accordion(
                label="Enter Asset Attributes(optional)", open=False
            ):
                asset_cat_text = gr.Textbox(
                    label="Enter Asset Category (e.g., chair)"
                )
                height_range_text = gr.Textbox(
                    label="Enter Height Range in meter (e.g., 0.5-0.6)"
                )
                mass_range_text = gr.Textbox(
                    label="Enter Mass Range in kg (e.g., 1.1-1.2)"
                )
                asset_version_text = gr.Textbox(
                    label=f"Enter version (e.g., {VERSION})"
                )
            with gr.Row():
                extract_urdf_btn = gr.Button(
                    "Extract URDF with physics(~1 mins)",
                    variant="primary",
                    interactive=False,
                )
            with gr.Row():
                gr.Markdown(
                    "#### Estimated Asset 3D Attributes(No input required)"
                )
            with gr.Row():
                est_type_text = gr.Textbox(
                    label="Asset category", interactive=False
                )
                est_height_text = gr.Textbox(
                    label="Real height(.m)", interactive=False
                )
                est_mass_text = gr.Textbox(
                    label="Mass(.kg)", interactive=False
                )
                est_mu_text = gr.Textbox(
                    label="Friction coefficient", interactive=False
                )
            with gr.Row():
                download_urdf = gr.DownloadButton(
                    label="Download URDF", variant="primary", interactive=False
                )
            gr.Markdown(
                """ NOTE: If `Asset Attributes` are provided, the provided
                properties will be used; otherwise, the GPT-preset properties
                will be applied. \n
                The `Download URDF` file is restored to the real scale and
                has quality inspection, open with an editor to view details.
                """
            )
            # Example galleries; only the one matching the active tab shows.
            with gr.Row() as single_image_example:
                image_examples = gr.Examples(
                    label="Image Gallery",
                    examples=[
                        [f"assets/example_image/{image}"]
                        # sorted(): os.listdir order is arbitrary; keep the
                        # gallery deterministic across filesystems.
                        for image in sorted(
                            os.listdir("assets/example_image")
                        )
                    ],
                    inputs=[image_prompt],
                    fn=preprocess_image_fn,
                    outputs=[image_prompt, raw_image_cache],
                    run_on_click=True,
                    examples_per_page=10,
                )
            with gr.Row(visible=False) as single_sam_image_example:
                sam_image_examples = gr.Examples(
                    label="Image Gallery",
                    examples=[
                        f"assets/example_image/{image}"
                        for image in sorted(
                            os.listdir("assets/example_image")
                        )
                    ],
                    inputs=[image_prompt_sam],
                    fn=preprocess_sam_image_fn,
                    outputs=[image_prompt_sam, raw_image_cache],
                    run_on_click=True,
                    examples_per_page=10,
                )
        with gr.Column(scale=1):
            video_output = gr.Video(
                label="Generated 3D Asset",
                autoplay=True,
                loop=True,
                height=300,
            )
            model_output_gs = gr.Model3D(
                label="Gaussian Representation", height=300, interactive=False
            )
            # Hidden path of the aligned Gaussian file from stage 2; stage 3
            # consumes it.
            aligned_gs = gr.Textbox(visible=False)
            gr.Markdown(
                """ The rendering of `Gaussian Representation` takes additional 10s. """  # noqa
            )
            with gr.Row():
                model_output_mesh = gr.Model3D(
                    label="Mesh Representation",
                    height=300,
                    interactive=False,
                    clear_color=[0.8, 0.8, 0.8, 1],
                    elem_id="lighter_mesh",
                )

    # Session state: which tab produced the input, the stage-1 pipeline
    # output, and the SAM points clicked so far.
    is_samimage = gr.State(False)
    output_buf = gr.State()
    selected_points = gr.State(value=[])

    demo.load(start_session)
    demo.unload(end_session)

    # Tab switches toggle the matching example gallery and record the active
    # input mode. FIX: gr.Row.update() was removed in Gradio 4 (this file
    # already uses Gradio-4 APIs such as Blocks(delete_cache=...) and
    # gr.Button(...) instances as updates) — return component instances.
    single_image_input_tab.select(
        lambda: (False, gr.Row(visible=True), gr.Row(visible=False)),
        outputs=[is_samimage, single_image_example, single_sam_image_example],
    )
    samimage_input_tab.select(
        lambda: (True, gr.Row(visible=True), gr.Row(visible=False)),
        outputs=[is_samimage, single_sam_image_example, single_image_example],
    )

    image_prompt.upload(
        preprocess_image_fn,
        inputs=[image_prompt],
        outputs=[image_prompt, raw_image_cache],
    )
    # Any change of the auto-seg input invalidates all downstream results.
    image_prompt.change(
        lambda: tuple(
            [
                gr.Button(interactive=False),
                gr.Button(interactive=False),
                gr.Button(interactive=False),
                None,
                "",
                None,
                None,
                "",
                "",
                "",
                "",
                "",
                "",
                "",
                "",
            ]
        ),
        outputs=[
            extract_rep3d_btn,
            extract_urdf_btn,
            download_urdf,
            model_output_gs,
            aligned_gs,
            model_output_mesh,
            video_output,
            asset_cat_text,
            height_range_text,
            mass_range_text,
            asset_version_text,
            est_type_text,
            est_height_text,
            est_mass_text,
            est_mu_text,
        ],
    )
    image_prompt.change(
        active_btn_by_content,
        inputs=image_prompt,
        outputs=generate_btn,
    )

    image_prompt_sam.upload(
        preprocess_sam_image_fn,
        inputs=[image_prompt_sam],
        outputs=[image_prompt_sam, raw_image_cache],
    )
    # SAM-input change additionally clears the mask preview and the clicked
    # point list.
    image_prompt_sam.change(
        lambda: tuple(
            [
                gr.Button(interactive=False),
                gr.Button(interactive=False),
                gr.Button(interactive=False),
                None,
                None,
                None,
                "",
                "",
                "",
                "",
                "",
                "",
                "",
                "",
                None,
                [],
            ]
        ),
        outputs=[
            extract_rep3d_btn,
            extract_urdf_btn,
            download_urdf,
            model_output_gs,
            model_output_mesh,
            video_output,
            asset_cat_text,
            height_range_text,
            mass_range_text,
            asset_version_text,
            est_type_text,
            est_height_text,
            est_mass_text,
            est_mu_text,
            image_mask_sam,
            selected_points,
        ],
    )
    # Each click on the SAM input records a fg/bg point and refreshes the
    # annotated mask plus the hidden segmented image.
    image_prompt_sam.select(
        select_point,
        [
            image_prompt_sam,
            selected_points,
            fg_bg_radio,
        ],
        [image_mask_sam, image_seg_sam],
    )
    image_seg_sam.change(
        active_btn_by_content,
        inputs=image_seg_sam,
        outputs=generate_btn,
    )

    # Stage 1: resolve the seed, run 3D generation, then enable stage 2.
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).success(
        image_to_3d,
        inputs=[
            image_prompt,
            seed,
            ss_guidance_strength,
            ss_sampling_steps,
            slat_guidance_strength,
            slat_sampling_steps,
            raw_image_cache,
            image_seg_sam,
            is_samimage,
        ],
        outputs=[output_buf, video_output],
    ).success(
        lambda: gr.Button(interactive=True),
        outputs=[extract_rep3d_btn],
    )
    # Stage 2: extract mesh/Gaussian representations, then enable stage 3.
    extract_rep3d_btn.click(
        extract_3d_representations_v2,
        inputs=[
            output_buf,
            project_delight,
        ],
        outputs=[
            model_output_mesh,
            model_output_gs,
            model_output_obj,
            aligned_gs,
        ],
    ).success(
        lambda: gr.Button(interactive=True),
        outputs=[extract_urdf_btn],
    )
    # Stage 3: build the physics URDF, then enable the download button.
    extract_urdf_btn.click(
        extract_urdf,
        inputs=[
            aligned_gs,
            model_output_obj,
            asset_cat_text,
            height_range_text,
            mass_range_text,
            asset_version_text,
        ],
        outputs=[
            download_urdf,
            est_type_text,
            est_height_text,
            est_mass_text,
            est_mu_text,
        ],
        queue=True,
        show_progress="full",
    ).success(
        lambda: gr.Button(interactive=True),
        outputs=[download_urdf],
    )


if __name__ == "__main__":
    demo.launch()