Spaces:
Running
on
Zero
Running
on
Zero
import os | |
os.environ["GRADIO_APP"] = "imageto3d" | |
import gradio as gr | |
from common import ( | |
MAX_SEED, | |
VERSION, | |
active_btn_by_content, | |
end_session, | |
extract_3d_representations_v2, | |
extract_urdf, | |
get_seed, | |
image_css, | |
image_to_3d, | |
lighting_css, | |
preprocess_image_fn, | |
preprocess_sam_image_fn, | |
select_point, | |
start_session, | |
) | |
from gradio.themes import Default | |
from gradio.themes.utils.colors import slate | |
# Gradio UI for the image-to-3D asset pipeline: an input column (auto / SAM
# segmentation tabs, generation settings, URDF export) and a result column
# (preview video, Gaussian and mesh viewers), followed by the event wiring.
with gr.Blocks(
    # delete_cache=(frequency, age), both in seconds: 43200 s == 12 h.
    delete_cache=(43200, 43200),
    theme=Default(primary_hue=slate),
) as demo:
    gr.Markdown(
        f"""
        ## Image to 3D Asset Pipeline \n
        version: {VERSION} \n
        <!-- The service is temporarily deployed on `dev015-10.34.8.82: CUDA 4`. -->
        """
    )
    # CSS snippets provided by `common`, injected as raw HTML.
    gr.HTML(image_css)
    gr.HTML(lighting_css)

    with gr.Row():
        # ------------------------------------------------------------------
        # Left column: inputs, settings and export controls.
        # ------------------------------------------------------------------
        with gr.Column(scale=2):
            with gr.Tabs() as input_tabs:
                with gr.Tab(
                    label="Image(auto seg)", id=0
                ) as single_image_input_tab:
                    # Hidden image slot; it is the second output of both
                    # preprocess functions and an input of `image_to_3d`.
                    raw_image_cache = gr.Image(
                        format="png",
                        image_mode="RGB",
                        type="pil",
                        visible=False,
                    )
                    image_prompt = gr.Image(
                        label="Input Image",
                        format="png",
                        image_mode="RGBA",
                        type="pil",
                        height=400,
                        elem_classes=["image_fit"],
                    )
                    gr.Markdown(
                        """
                        If you are not satisfied with the auto segmentation
                        result, please switch to the `Image(SAM seg)` tab."""
                    )
                with gr.Tab(
                    label="Image(SAM seg)", id=1
                ) as samimage_input_tab:
                    with gr.Row():
                        with gr.Column(scale=1):
                            image_prompt_sam = gr.Image(
                                label="Input Image",
                                type="numpy",
                                height=400,
                                elem_classes=["image_fit"],
                            )
                            # Hidden; written by `select_point` and passed
                            # to `image_to_3d`.
                            image_seg_sam = gr.Image(
                                label="SAM Seg Image",
                                image_mode="RGBA",
                                type="pil",
                                height=400,
                                visible=False,
                            )
                        with gr.Column(scale=1):
                            image_mask_sam = gr.AnnotatedImage(
                                elem_classes=["image_fit"]
                            )

                    fg_bg_radio = gr.Radio(
                        ["foreground_point", "background_point"],
                        label="Select foreground(green) or background(red) points, by default foreground",  # noqa
                        value="foreground_point",
                    )
                    gr.Markdown(
                        """ Click the `Input Image` to select SAM points,
                        after get the satisified segmentation, click `Generate`
                        button to generate the 3D asset. \n
                        Note: If the segmented foreground is too small relative
                        to the entire image area, the generation will fail.
                        """
                    )

            with gr.Accordion(label="Generation Settings", open=False):
                with gr.Row():
                    seed = gr.Slider(
                        0, MAX_SEED, label="Seed", value=0, step=1
                    )
                with gr.Row():
                    randomize_seed = gr.Checkbox(
                        label="Randomize Seed", value=False
                    )
                    project_delight = gr.Checkbox(
                        label="Backproject delighting",
                        value=True,
                    )
                gr.Markdown("Geo Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(
                        0.0,
                        10.0,
                        label="Guidance Strength",
                        value=7.5,
                        step=0.1,
                    )
                    ss_sampling_steps = gr.Slider(
                        1, 50, label="Sampling Steps", value=12, step=1
                    )
                gr.Markdown("Visual Appearance Generation")
                with gr.Row():
                    slat_guidance_strength = gr.Slider(
                        0.0,
                        10.0,
                        label="Guidance Strength",
                        value=3.0,
                        step=0.1,
                    )
                    slat_sampling_steps = gr.Slider(
                        1, 50, label="Sampling Steps", value=12, step=1
                    )

            # The three stage buttons start disabled; each one is enabled by
            # the success callback of the previous stage (see wiring below).
            generate_btn = gr.Button(
                "Generate(~0.5 mins)", variant="primary", interactive=False
            )
            # Hidden textbox: output of the extraction step, input of
            # `extract_urdf`.
            model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False)
            with gr.Row():
                extract_rep3d_btn = gr.Button(
                    "Extract 3D Representation(~2 mins)",
                    variant="primary",
                    interactive=False,
                )
            with gr.Accordion(
                label="Enter Asset Attributes(optional)", open=False
            ):
                asset_cat_text = gr.Textbox(
                    label="Enter Asset Category (e.g., chair)"
                )
                height_range_text = gr.Textbox(
                    label="Enter Height Range in meter (e.g., 0.5-0.6)"
                )
                mass_range_text = gr.Textbox(
                    label="Enter Mass Range in kg (e.g., 1.1-1.2)"
                )
                asset_version_text = gr.Textbox(
                    label=f"Enter version (e.g., {VERSION})"
                )
            with gr.Row():
                extract_urdf_btn = gr.Button(
                    "Extract URDF with physics(~1 mins)",
                    variant="primary",
                    interactive=False,
                )
            with gr.Row():
                gr.Markdown(
                    "#### Estimated Asset 3D Attributes(No input required)"
                )
            with gr.Row():
                est_type_text = gr.Textbox(
                    label="Asset category", interactive=False
                )
                est_height_text = gr.Textbox(
                    label="Real height(.m)", interactive=False
                )
                est_mass_text = gr.Textbox(
                    label="Mass(.kg)", interactive=False
                )
                est_mu_text = gr.Textbox(
                    label="Friction coefficient", interactive=False
                )
            with gr.Row():
                download_urdf = gr.DownloadButton(
                    label="Download URDF", variant="primary", interactive=False
                )
            gr.Markdown(
                """ NOTE: If `Asset Attributes` are provided, the provided
                properties will be used; otherwise, the GPT-preset properties
                will be applied. \n
                The `Download URDF` file is restored to the real scale and
                has quality inspection, open with an editor to view details.
                """
            )

            # List the bundled example images once and sort them:
            # `os.listdir` returns entries in arbitrary order, which would
            # make the gallery order differ between runs and between tabs.
            example_dir = "assets/example_image"
            example_paths = [
                f"{example_dir}/{name}"
                for name in sorted(os.listdir(example_dir))
            ]

            # NOTE(review): the two gallery rows are assumed to belong to the
            # input column -- confirm against the intended layout.
            with gr.Row() as single_image_example:
                examples = gr.Examples(
                    label="Image Gallery",
                    examples=[[path] for path in example_paths],
                    inputs=[image_prompt],
                    fn=preprocess_image_fn,
                    outputs=[image_prompt, raw_image_cache],
                    run_on_click=True,
                    examples_per_page=10,
                )
            # Hidden until the SAM tab is selected (see tab wiring below).
            with gr.Row(visible=False) as single_sam_image_example:
                examples = gr.Examples(
                    label="Image Gallery",
                    examples=[[path] for path in example_paths],
                    inputs=[image_prompt_sam],
                    fn=preprocess_sam_image_fn,
                    outputs=[image_prompt_sam, raw_image_cache],
                    run_on_click=True,
                    examples_per_page=10,
                )

        # ------------------------------------------------------------------
        # Right column: generated results.
        # ------------------------------------------------------------------
        with gr.Column(scale=1):
            video_output = gr.Video(
                label="Generated 3D Asset",
                autoplay=True,
                loop=True,
                height=300,
            )
            model_output_gs = gr.Model3D(
                label="Gaussian Representation", height=300, interactive=False
            )
            # Hidden textbox: output of the extraction step, input of
            # `extract_urdf`.
            aligned_gs = gr.Textbox(visible=False)
            gr.Markdown(
                """ The rendering of `Gaussian Representation` takes additional 10s. """  # noqa
            )
            with gr.Row():
                model_output_mesh = gr.Model3D(
                    label="Mesh Representation",
                    height=300,
                    interactive=False,
                    clear_color=[0.8, 0.8, 0.8, 1],
                    elem_id="lighter_mesh",
                )

    # Per-session state.
    is_samimage = gr.State(False)  # True while the SAM tab is selected
    output_buf = gr.State()  # result of `image_to_3d`, read by extraction
    selected_points = gr.State(value=[])  # passed to `select_point`

    demo.load(start_session)
    demo.unload(end_session)

    # ----------------------------------------------------------------------
    # Tab switching: record the active mode and show the matching gallery.
    # `gr.update` is the generic, component-agnostic update helper.
    # ----------------------------------------------------------------------
    single_image_input_tab.select(
        lambda: (False, gr.update(visible=True), gr.update(visible=False)),
        outputs=[is_samimage, single_image_example, single_sam_image_example],
    )
    samimage_input_tab.select(
        lambda: (True, gr.update(visible=True), gr.update(visible=False)),
        outputs=[is_samimage, single_sam_image_example, single_image_example],
    )

    # ----------------------------------------------------------------------
    # Reset on input change. The outputs are grouped by the value that
    # clears them, and the value list is derived from the same groups, so
    # values and outputs can never get out of step and both tabs clear the
    # same set of components.
    # ----------------------------------------------------------------------
    downstream_buttons = [extract_rep3d_btn, extract_urdf_btn, download_urdf]
    result_viewers = [model_output_gs, model_output_mesh, video_output]
    text_fields = [
        aligned_gs,
        asset_cat_text,
        height_range_text,
        mass_range_text,
        asset_version_text,
        est_type_text,
        est_height_text,
        est_mass_text,
        est_mu_text,
    ]
    shared_reset_outputs = downstream_buttons + result_viewers + text_fields

    def _shared_reset_values() -> list:
        """Return one reset value per entry of ``shared_reset_outputs``."""
        return (
            [gr.Button(interactive=False) for _ in downstream_buttons]
            + [None] * len(result_viewers)
            + [""] * len(text_fields)
        )

    def _enabled_button() -> gr.Button:
        """Return an update that makes the target button clickable."""
        return gr.Button(interactive=True)

    # ----------------------------------------------------------------------
    # Auto-segmentation tab.
    # ----------------------------------------------------------------------
    image_prompt.upload(
        preprocess_image_fn,
        inputs=[image_prompt],
        outputs=[image_prompt, raw_image_cache],
    )
    image_prompt.change(
        lambda: tuple(_shared_reset_values()),
        outputs=shared_reset_outputs,
    )
    image_prompt.change(
        active_btn_by_content,
        inputs=image_prompt,
        outputs=generate_btn,
    )

    # ----------------------------------------------------------------------
    # SAM segmentation tab.
    # ----------------------------------------------------------------------
    image_prompt_sam.upload(
        preprocess_sam_image_fn,
        inputs=[image_prompt_sam],
        outputs=[image_prompt_sam, raw_image_cache],
    )
    # Same reset as the auto-seg tab, plus the mask preview and the list of
    # clicked points.
    image_prompt_sam.change(
        lambda: tuple(_shared_reset_values() + [None, []]),
        outputs=shared_reset_outputs + [image_mask_sam, selected_points],
    )
    image_prompt_sam.select(
        select_point,
        [
            image_prompt_sam,
            selected_points,
            fg_bg_radio,
        ],
        [image_mask_sam, image_seg_sam],
    )
    image_seg_sam.change(
        active_btn_by_content,
        inputs=image_seg_sam,
        outputs=generate_btn,
    )

    # ----------------------------------------------------------------------
    # Pipeline stages. Each `.success` step runs only if the previous step
    # finished without error, and ends by enabling the next stage's button.
    # ----------------------------------------------------------------------
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).success(
        image_to_3d,
        inputs=[
            image_prompt,
            seed,
            ss_guidance_strength,
            ss_sampling_steps,
            slat_guidance_strength,
            slat_sampling_steps,
            raw_image_cache,
            image_seg_sam,
            is_samimage,
        ],
        outputs=[output_buf, video_output],
    ).success(
        _enabled_button,
        outputs=[extract_rep3d_btn],
    )

    extract_rep3d_btn.click(
        extract_3d_representations_v2,
        inputs=[
            output_buf,
            project_delight,
        ],
        outputs=[
            model_output_mesh,
            model_output_gs,
            model_output_obj,
            aligned_gs,
        ],
    ).success(
        _enabled_button,
        outputs=[extract_urdf_btn],
    )

    extract_urdf_btn.click(
        extract_urdf,
        inputs=[
            aligned_gs,
            model_output_obj,
            asset_cat_text,
            height_range_text,
            mass_range_text,
            asset_version_text,
        ],
        outputs=[
            download_urdf,
            est_type_text,
            est_height_text,
            est_mass_text,
            est_mu_text,
        ],
        queue=True,
        show_progress="full",
    ).success(
        _enabled_button,
        outputs=[download_urdf],
    )


if __name__ == "__main__":
    demo.launch()