Spaces:
Running
on
Zero
Running
on
Zero
import json | |
import os | |
import random | |
import tempfile | |
from typing import Any, List, Union | |
import gradio as gr | |
import numpy as np | |
import spaces | |
import torch | |
import trimesh | |
from gradio_image_prompter import ImagePrompter | |
from gradio_litmodel3d import LitModel3D | |
from huggingface_hub import snapshot_download | |
from PIL import Image | |
from skimage import measure | |
from transformers import AutoModelForMaskGeneration, AutoProcessor | |
from midi.pipelines.pipeline_midi import MIDIPipeline | |
from midi.utils.smoothing import smooth_gpu | |
from scripts.grounding_sam import plot_segmentation, segment | |
from scripts.inference_midi import preprocess_image, split_rgb_mask | |
# Constants
# Largest value offered by the seed slider and drawn when the seed is randomized.
MAX_SEED = np.iinfo(np.int32).max
# "tmp" directory located next to this file; exported GLB files are written here.
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
# dtype and device that both models are moved to after loading.
DTYPE = torch.bfloat16
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Hugging Face Hub repository holding the MIDI-3D weights.
REPO_ID = "VAST-AI/MIDI-3D"
# Instructions rendered at the top of the demo page.
MARKDOWN = """
## Image to 3D Scene with [MIDI-3D](https://huanngzh.github.io/MIDI-Page/)
<b>Important!</b> Please check out our [instruction video](https://github.com/user-attachments/assets/814c046e-f5c3-47cf-bb56-60154be8374c)!
1. Upload an image, and draw bounding boxes for each instance by holding and dragging the mouse. Then click "Run Segmentation" to generate the segmentation result. <b>Ensure instances should not be too small and bounding boxes fit snugly around each instance.</b>
2. <b>Check "Do image padding" in "Generation Settings" if instances in your image are too close to the image border.</b> Then click "Run Generation" to generate a 3D scene from the image and segmentation result.
3. If you find the generated 3D scene satisfactory, download it by clicking the "Download GLB" button.
"""
# Example rows for the demo gallery. Each row is
# [image prompt, segmentation map path, seed, randomize seed, do image padding].
_EXAMPLE_ROOT = "assets/example_data"
_EXAMPLE_SPECS = (
    # (style folder, sample id, do image padding)
    ("Cartoon-Style", "03", False),
    ("Cartoon-Style", "01", False),
    ("Realistic-Style", "02", False),
    ("Cartoon-Style", "00", False),
    ("Realistic-Style", "00", True),
    ("Realistic-Style", "01", True),
    ("Realistic-Style", "05", False),
)
EXAMPLES = [
    [
        {"image": f"{_EXAMPLE_ROOT}/{style}/{sample}_rgb.png"},
        f"{_EXAMPLE_ROOT}/{style}/{sample}_seg.png",
        42,
        False,
        padding,
    ]
    for style, sample, padding in _EXAMPLE_SPECS
]
# Make sure the export directory exists before any request is served.
os.makedirs(TMP_DIR, exist_ok=True)

# Prepare models
## Grounding SAM
segmenter_id = "facebook/sam-vit-base"
sam_processor = AutoProcessor.from_pretrained(segmenter_id)
sam_segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id)
sam_segmentator = sam_segmentator.to(DEVICE, DTYPE)

## MIDI-3D
# Download the weights into a local folder, then load the pipeline from it.
local_dir = "pretrained_weights/MIDI-3D"
snapshot_download(repo_id=REPO_ID, local_dir=local_dir)
pipe: MIDIPipeline = MIDIPipeline.from_pretrained(local_dir).to(DEVICE, DTYPE)
# Register the custom adapter on "blocks.8" through "blocks.12".
pipe.init_custom_adapter(
    set_self_attn_module_names=[f"blocks.{index}" for index in range(8, 13)]
)
# Utils | |
def get_random_hex():
    """Return a 16-character hex string built from 8 bytes of ``os.urandom``."""
    return os.urandom(8).hex()
def run_segmentation(image_prompts: Any, polygon_refinement: bool) -> Image.Image:
    """Segment the instances the user boxed and return the rendered result.

    Args:
        image_prompts: Value of the image prompter component. It is read as a
            mapping with an ``"image"`` entry (converted to RGB here) and a
            ``"points"`` entry holding one sequence per drawn box.
        polygon_refinement: Forwarded unchanged to ``segment``.

    Returns:
        The image produced by ``plot_segmentation`` for the detections.

    Raises:
        gr.Error: If no image is available or no bounding box was drawn.
    """
    image = image_prompts["image"] if image_prompts else None
    if image is None:
        raise gr.Error("Please upload an image first.")

    points = image_prompts["points"]
    if points is None or len(points) == 0:
        # The error must be raised: constructing ``gr.Error`` without ``raise``
        # neither stops execution nor shows anything to the user.
        raise gr.Error("Please draw bounding boxes for each instance on the image.")

    rgb_image = image.convert("RGB")

    # pre-process the prompts and get the xyxy box of each one: indices 0/1 and
    # 3/4 of every entry are used as the two corners.
    # NOTE(review): index 2 is presumably the prompter's point-type flag —
    # confirm against gradio_image_prompter.
    boxes = [
        [
            [int(box[0]), int(box[1]), int(box[3]), int(box[4])]
            for box in points
        ]
    ]

    # run the segmentation
    # NOTE(review): ``boxes`` is already wrapped in a list above and is wrapped
    # once more here; the nesting is kept exactly as before — confirm it is the
    # shape ``segment`` expects.
    detections = segment(
        sam_processor,
        sam_segmentator,
        rgb_image,
        boxes=[boxes],
        polygon_refinement=polygon_refinement,
    )
    seg_map_pil = plot_segmentation(rgb_image, detections)

    torch.cuda.empty_cache()
    return seg_map_pil
def run_midi(
    pipe: Any,
    rgb_image: Union[str, Image.Image],
    seg_image: Union[str, Image.Image],
    seed: int,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.0,
    do_image_padding: bool = False,
) -> Any:
    """Run the MIDI pipeline on one image / segmentation-map pair.

    Args:
        pipe: Callable pipeline exposing a ``device`` attribute.
        rgb_image: Input image (path or PIL image).
        seg_image: Segmentation map matching ``rgb_image`` (path or PIL image).
        seed: Seed for the ``torch.Generator`` handed to the pipeline.
        num_inference_steps: Forwarded to the pipeline.
        guidance_scale: Forwarded to the pipeline.
        do_image_padding: When true, both images go through
            ``preprocess_image`` before being split into instances.

    Returns:
        Whatever ``pipe`` returns with ``return_dict=False``. This is the raw
        pipeline output, not a ``trimesh.Scene``; meshes are built from it by
        the caller.
    """
    if do_image_padding:
        rgb_image, seg_image = preprocess_image(rgb_image, seg_image)

    # Split the scene into per-instance images and masks plus scene images.
    instance_rgbs, instance_masks, scene_rgbs = split_rgb_mask(rgb_image, seg_image)
    num_instances = len(instance_rgbs)

    # The generator is created on the pipeline's device and seeded per call.
    outputs = pipe(
        image=instance_rgbs,
        mask=instance_masks,
        image_scene=scene_rgbs,
        attention_kwargs={"num_instances": num_instances},
        generator=torch.Generator(device=pipe.device).manual_seed(seed),
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        decode_progressive=True,
        return_dict=False,
    )
    return outputs
def run_generation(
    rgb_image: Any,
    seg_image: Union[str, Image.Image],
    seed: int,
    randomize_seed: bool = False,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.0,
    do_image_padding: bool = False,
):
    """Generate a 3D scene from an image and its segmentation, exported as GLB.

    Args:
        rgb_image: A PIL image, or a mapping whose ``"image"`` entry is used.
        seg_image: Segmentation map passed on to ``run_midi``.
        seed: Seed to use unless ``randomize_seed`` is set.
        randomize_seed: Replace ``seed`` with a random value in [0, MAX_SEED].
        num_inference_steps: Passed on to ``run_midi``.
        guidance_scale: Passed on to ``run_midi``.
        do_image_padding: Passed on to ``run_midi``.

    Returns:
        ``(glb_path, glb_path, seed)``: the exported file path twice, followed
        by the seed that was actually used.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Anything that is not already a PIL image is checked for an "image" entry.
    already_pil = isinstance(rgb_image, Image.Image)
    if not already_pil and "image" in rgb_image:
        rgb_image = rgb_image["image"]

    outputs = run_midi(
        pipe,
        rgb_image,
        seg_image,
        seed,
        num_inference_steps,
        guidance_scale,
        do_image_padding,
    )

    # marching cubes: one mesh per tuple obtained by zipping the pipeline outputs
    meshes = []
    for logits_, grid_size, bbox_size, bbox_min, _bbox_max in zip(*outputs):
        volume = smooth_gpu(logits_.view(grid_size), method="gaussian", sigma=1)
        torch.cuda.empty_cache()

        vertices, faces, _normals, _values = measure.marching_cubes(
            volume.float().cpu().numpy(), 0, method="lewiner"
        )
        # Rescale grid indices by the bounding-box size and shift by its minimum.
        vertices = vertices / grid_size * bbox_size + bbox_min

        meshes.append(
            trimesh.Trimesh(vertices.astype(np.float32), np.ascontiguousarray(faces))
        )

    # compose the output meshes and write them under a random file name
    scene = trimesh.Scene(meshes)
    glb_path = os.path.join(TMP_DIR, f"midi3d_{get_random_hex()}.glb")
    scene.export(glb_path)

    torch.cuda.empty_cache()
    return glb_path, glb_path, seed
# Demo
def _run_generation_for_examples(
    rgb_image, seg_image, seed, randomize_seed, do_image_padding
):
    """Adapt the five example columns to ``run_generation``'s signature.

    The example rows carry (image, segmentation, seed, randomize seed, padding).
    Passing them positionally to ``run_generation`` would put the padding flag
    into ``num_inference_steps``, so it is forwarded by keyword instead.
    """
    return run_generation(
        rgb_image,
        seg_image,
        seed,
        randomize_seed,
        do_image_padding=do_image_padding,
    )


with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)

    with gr.Row():
        # Left column: inputs, segmentation and generation controls.
        with gr.Column():
            with gr.Row():
                image_prompts = ImagePrompter(label="Input Image", type="pil")
                seg_image = gr.Image(
                    label="Segmentation Result", type="pil", format="png"
                )

            with gr.Accordion("Segmentation Settings", open=False):
                polygon_refinement = gr.Checkbox(
                    label="Polygon Refinement", value=False
                )
            seg_button = gr.Button("Run Segmentation")

            with gr.Accordion("Generation Settings", open=False):
                do_image_padding = gr.Checkbox(label="Do image padding", value=False)
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=50,
                )
                guidance_scale = gr.Slider(
                    label="CFG scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=7.0,
                )
            gen_button = gr.Button("Run Generation", variant="primary")

        # Right column: 3D viewer and download button (disabled until a result exists).
        with gr.Column():
            model_output = LitModel3D(label="Generated GLB", exposure=1.0, height=500)
            download_glb = gr.DownloadButton(label="Download GLB", interactive=False)

    with gr.Row():
        gr.Examples(
            examples=EXAMPLES,
            fn=_run_generation_for_examples,
            inputs=[image_prompts, seg_image, seed, randomize_seed, do_image_padding],
            outputs=[model_output, download_glb, seed],
            cache_examples=False,
        )

    seg_button.click(
        run_segmentation,
        inputs=[
            image_prompts,
            polygon_refinement,
        ],
        outputs=[seg_image],
    ).then(lambda: gr.Button(interactive=True), outputs=[gen_button])

    # After a generation finishes, enable the download button.
    gen_button.click(
        run_generation,
        inputs=[
            image_prompts,
            seg_image,
            seed,
            randomize_seed,
            num_inference_steps,
            guidance_scale,
            do_image_padding,
        ],
        outputs=[model_output, download_glb, seed],
    ).then(lambda: gr.Button(interactive=True), outputs=[download_glb])

demo.launch()