"""Gradio demo for Amodal3R.

Users place SAM point prompts to build a visibility mask for the target object
and masks for its occluders, then run amodal 3D reconstruction and export the
result as a GLB mesh or a PLY Gaussian splat.
"""
import gradio as gr
import spaces

import os
import shutil
# NOTE(review): set before the Amodal3R imports below, presumably because a
# sparse-convolution backend reads it at import time — confirm before moving.
os.environ['SPCONV_ALGO'] = 'native'
from typing import *
import torch
import numpy as np
import imageio
from easydict import EasyDict as edict
from PIL import Image
from Amodal3R.pipelines import Amodal3RImageTo3DPipeline
from Amodal3R.representations import Gaussian, MeshExtractResult
from Amodal3R.utils import render_utils, postprocessing_utils
from segment_anything import sam_model_registry, SamPredictor
from huggingface_hub import hf_hub_download
import cv2

MAX_SEED = np.iinfo(np.int32).max
# Per-session scratch directories live under ./tmp next to this file.
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)
# NOTE(review): presumably required by a torch.distributed rendezvous inside
# the pipeline — confirm whether these are still needed.
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'


def start_session(req: gr.Request):
    """Create the scratch directory for the requesting session under TMP_DIR."""
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)


def end_session(req: gr.Request):
    """Remove the requesting session's scratch directory.

    ``ignore_errors=True`` keeps cleanup best-effort: the directory may not
    exist (e.g. ``start_session`` never ran for this session, or it was already
    removed), and a failing cleanup handler should not raise.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    shutil.rmtree(user_dir, ignore_errors=True)


def change_message():
    """Return the status text shown while a newly loaded image is prepared."""
    return "Please wait for a few seconds after uploading the image."
def reset_image(predictor, img):
    """Embed a newly loaded image with SAM and reset all point prompts.

    Args:
        predictor: The session's SAM predictor (``SamPredictor``).
        img: Image from ``input_image`` (a PIL image, since that component uses
            ``type='pil'``), or ``None`` if nothing has been loaded yet.

    Returns:
        A 7-tuple matching the wired outputs: the predictor, the image as a
        numpy array (stored as ``original_image``), a status message, three
        fresh empty lists (visible points, occlusion points, occlusion masks),
        and the image to show in the prompt preview.
    """
    if img is None:
        # np.array(None) would reach SAM as a 0-d object array and crash; this
        # happens when "Load Example Image" is clicked with no image selected.
        return predictor, None, "Please upload an image first.", [], [], [], None
    img = np.array(img)
    predictor.set_image(img)
    original_img = img.copy()
    return predictor, original_img, "The models are ready.", [], [], [], original_img


def _reset_visibility_state():
    """Drop visibility masks computed for a previously loaded image.

    ``reset_image`` clears the point lists but cannot reach these states; stale
    masks from an earlier image would otherwise be OR-ed into (or fail to
    broadcast against) masks of the new image.

    Returns:
        ``(visibility_mask_list, visibility_mask)`` reset to ``([], None)``.
    """
    return [], None


def button_clickable(selected_points):
    """Return a Button update that is interactive only when points exist."""
    # Gradio 4 removed ``gr.Button.update``; returning a component instance
    # carrying the changed props is the supported way to update it.
    return gr.Button(interactive=len(selected_points) > 0)


def run_sam(img, predictor, selected_points):
    """Segment one object from positive point prompts with SAM.

    Args:
        img: The image previously passed to ``predictor.set_image``; only its
            height/width are used here (for the empty-prompt case).
        predictor: SAM predictor with the image already set.
        selected_points: Sequence of ``[x, y]`` pixel coordinates, all treated
            as foreground (label 1).

    Returns:
        A uint8 mask with ``img``'s height/width (presumably 0/1, since callers
        scale it by 255 for display). All zeros when there are no points.
    """
    if len(selected_points) == 0:
        return np.zeros(img.shape[:2], dtype=np.uint8)
    input_points = list(selected_points)
    input_labels = [1] * len(selected_points)
    masks, _, _ = predictor.predict(
        point_coords=np.array(input_points),
        point_labels=np.array(input_labels),
        multimask_output=False,
    )
    best_mask = masks[0].astype(np.uint8)
    # With several prompts, apply a morphological closing (dilate then erode)
    # to fill small gaps between the regions the points selected.
    if len(selected_points) > 1:
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        best_mask = cv2.dilate(best_mask, kernel, iterations=1)
        best_mask = cv2.erode(best_mask, kernel, iterations=1)
    return best_mask


@spaces.GPU
def image_to_3d(
    image: np.ndarray,
    mask: np.ndarray,
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    erode_kernel_size: int,
    req: gr.Request,
) -> Tuple[dict, str]:
    """Run amodal 3D reconstruction and render a turntable preview video.

    Relies on the module-level ``pipeline`` created in the ``__main__`` block.

    Args:
        image: Prepared input image (output of ``check_combined_mask``).
        mask: Matching occlusion mask (output of ``check_combined_mask``).
        seed: Sampling seed.
        ss_guidance_strength / ss_sampling_steps: Sparse-structure sampler
            CFG strength and step count.
        slat_guidance_strength / slat_sampling_steps: Structured-latent
            sampler CFG strength and step count.
        erode_kernel_size: Forwarded to the pipeline to erode the visible area.
        req: Gradio request, used to locate the session directory.

    Returns:
        ``(state, video_path)``: a ``pack_state`` dict of the Gaussian and mesh
        results, and the path of an mp4 showing the color render (left) next
        to the normal render (right).
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    outputs = pipeline.run_multi_image(
        [image],
        [mask],
        seed=seed,
        formats=["gaussian", "mesh"],
        sparse_structure_sampler_params={
            "steps": ss_sampling_steps,
            "cfg_strength": ss_guidance_strength,
        },
        slat_sampler_params={
            "steps": slat_sampling_steps,
            "cfg_strength": slat_guidance_strength,
        },
        mode="stochastic",
        erode_kernel_size=erode_kernel_size,
    )
    video = render_utils.render_video(outputs['gaussian'][0], num_frames=120, bg_color=(1, 1, 1))['color']
    video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
    # Concatenate along width: appearance on the left, geometry on the right.
    video = [np.concatenate([frame, geo], axis=1) for frame, geo in zip(video, video_geo)]
    video_path = os.path.join(user_dir, 'sample.mp4')
    imageio.mimsave(video_path, video, fps=15)
    state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
    torch.cuda.empty_cache()
    return state, video_path


@spaces.GPU(duration=90)
def extract_glb(
    state: dict,
    mesh_simplify: float,
    texture_size: int,
    req: gr.Request,
) -> tuple:
    """Bake the packed reconstruction into a textured GLB in the session dir.

    Returns:
        The GLB path twice, once for the 3D viewer and once for the download
        button.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    gs, mesh = unpack_state(state)
    glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
    glb_path = os.path.join(user_dir, 'sample.glb')
    glb.export(glb_path)
    torch.cuda.empty_cache()
    return glb_path, glb_path


@spaces.GPU
def extract_gaussian(state: dict, req: gr.Request) -> tuple:
    """Save the packed Gaussian splat as a PLY file in the session dir.

    Returns:
        The PLY path twice (viewer + download outputs).
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    gs, _ = unpack_state(state)
    gaussian_path = os.path.join(user_dir, 'sample.ply')
    gs.save_ply(gaussian_path)
    torch.cuda.empty_cache()
    return gaussian_path, gaussian_path


def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
    """Move Gaussian and mesh tensors to CPU numpy arrays for session state."""
    return {
        'gaussian': {
            **gs.init_params,
            '_xyz': gs._xyz.cpu().numpy(),
            '_features_dc': gs._features_dc.cpu().numpy(),
            '_scaling': gs._scaling.cpu().numpy(),
            '_rotation': gs._rotation.cpu().numpy(),
            '_opacity': gs._opacity.cpu().numpy(),
        },
        'mesh': {
            'vertices': mesh.vertices.cpu().numpy(),
            'faces': mesh.faces.cpu().numpy(),
        },
    }


def unpack_state(state: dict) -> tuple:
    """Rebuild a CUDA ``Gaussian`` and an edict mesh from ``pack_state`` output.

    Returns:
        ``(gaussian, mesh)`` where ``mesh`` has ``vertices`` and ``faces``
        tensors on the CUDA device.
    """
    gaussian_state = state['gaussian']
    gs = Gaussian(
        aabb=gaussian_state['aabb'],
        sh_degree=gaussian_state['sh_degree'],
        # Spelling matches the keyword stored in ``gs.init_params``.
        mininum_kernel_size=gaussian_state['mininum_kernel_size'],
        scaling_bias=gaussian_state['scaling_bias'],
        opacity_bias=gaussian_state['opacity_bias'],
        scaling_activation=gaussian_state['scaling_activation'],
    )
    gs._xyz = torch.tensor(gaussian_state['_xyz'], device='cuda')
    gs._features_dc = torch.tensor(gaussian_state['_features_dc'], device='cuda')
    gs._scaling = torch.tensor(gaussian_state['_scaling'], device='cuda')
    gs._rotation = torch.tensor(gaussian_state['_rotation'], device='cuda')
    gs._opacity = torch.tensor(gaussian_state['_opacity'], device='cuda')
    mesh = edict(
        vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
        faces=torch.tensor(state['mesh']['faces'], device='cuda'),
    )
    return gs, mesh


def get_sam_predictor():
    """Download the SAM ViT-H checkpoint and wrap it in a ``SamPredictor``."""
    sam_checkpoint = hf_hub_download("ybelkada/segment-anything", "checkpoints/sam_vit_h_4b8939.pth")
    model_type = "vit_h"
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam_predictor = SamPredictor(sam)
    return sam_predictor


def _draw_points(image, points, color):
    """Return a copy of ``image`` with a filled radius-10 dot at each (x, y)."""
    canvas = image.copy()
    for p in points:
        cv2.circle(canvas, (int(p[0]), int(p[1])), radius=10, color=color, thickness=-1)
    return canvas


def _point_choices(points):
    """Format points as the "(x, y)" labels used by the points dropdown."""
    return [f"({p[0]}, {p[1]})" for p in points]


def draw_points_on_image(image, point):
    """Return a copy of ``image`` with a single red dot at ``point``."""
    return _draw_points(image, [point], (255, 0, 0))


def see_point(image, x, y):
    """Return a copy of ``image`` with a single red dot at ``(x, y)``."""
    return draw_points_on_image(image, [x, y])


def add_point(x, y, visible_points):
    """Append ``[x, y]`` unless already present; mutates and returns the list."""
    if [x, y] not in visible_points:
        visible_points.append([x, y])
    return visible_points


def delete_point(visible_points):
    """Remove the most recently added point, if any; returns the list."""
    if visible_points:
        visible_points.pop()
    return visible_points


def clear_all_points(image):
    """Return an unannotated copy of ``image``."""
    return image.copy()


def see_visible_points(image, visible_points):
    """Draw visibility prompts in (255, 0, 0) — red, assuming RGB order."""
    return _draw_points(image, visible_points, (255, 0, 0))


def see_occlusion_points(image, occlusion_points):
    """Draw occlusion prompts in (0, 255, 0) — green."""
    return _draw_points(image, occlusion_points, (0, 255, 0))


def update_all_points(points):
    """Build the points summary text and a refreshed, unselected dropdown."""
    text = f"Points: {points}"
    return text, gr.Dropdown(show_label=False, choices=_point_choices(points), value=None, interactive=True)


def delete_selected(image, visible_points, occlusion_points, occlusion_mask_list, selected_value, point_type):
    """Delete the dropdown-selected point of the active prompt type.

    Lists are mutated in place. For occlusion points the per-point mask at the
    same index is removed too, keeping ``occlusion_mask_list`` aligned.

    Returns:
        ``(annotated_image, visible_points, occlusion_points, text, dropdown)``;
        the image is ``None`` when no image is loaded.
    """
    if point_type == "visibility":
        target_points, paired_masks = visible_points, None
    else:
        target_points, paired_masks = occlusion_points, occlusion_mask_list

    choices = _point_choices(target_points)
    if selected_value in choices:
        index = choices.index(selected_value)
        target_points.pop(index)
        if paired_masks is not None and index < len(paired_masks):
            paired_masks.pop(index)

    updated_image = None
    if image is not None:
        updated_image = see_visible_points(image, visible_points)
        updated_image = see_occlusion_points(updated_image, occlusion_points)
    updated_text, dropdown = update_all_points(target_points)
    return updated_image, visible_points, occlusion_points, updated_text, dropdown


def add_current_mask(visibility_mask, visibilty_mask_list, point_type):
    """Save the current visibility mask into the list (skipping repeats).

    The misspelled parameter name is kept for interface compatibility.
    """
    if point_type != "visibility":
        # Occlusion masks are added automatically, so nothing to do here.
        return visibilty_mask_list
    if visibility_mask is None:
        # No mask computed yet; a None entry would break mask combination.
        return visibilty_mask_list
    if visibilty_mask_list and np.array_equal(visibility_mask, visibilty_mask_list[-1]):
        return visibilty_mask_list
    visibilty_mask_list.append(visibility_mask)
    return visibilty_mask_list


def apply_mask_overlay(image, mask, color=(255, 0, 0)):
    """Grey out pixels outside ``mask`` and outline the mask's outer contours.

    Args:
        image: HxWx3 uint8 image (not modified).
        mask: Single-channel uint8 mask; non-zero pixels are the region.
        color: Contour color.
    """
    overlay = image.copy()
    gray_color = np.array([200, 200, 200], dtype=np.uint8)
    non_mask = mask == 0
    # 50/50 blend with light grey de-emphasises everything outside the mask.
    overlay[non_mask] = (0.5 * overlay[non_mask] + 0.5 * gray_color).astype(np.uint8)
    contours, _ = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(overlay, contours, -1, color, 2)
    return overlay


def vis_mask(image, mask_list):
    """Overlay the union of ``mask_list`` on ``image``."""
    combined_mask = np.zeros_like(image[:, :, 0])
    for mask in mask_list:
        combined_mask = cv2.bitwise_or(combined_mask, mask)
    return apply_mask_overlay(image.copy(), combined_mask)


def segment_and_overlay(image, points, sam_predictor, mask_list, point_type):
    """Run SAM for the current prompts and render the mask preview.

    Visibility: one SAM call with all points, OR-ed with the saved masks in
    ``mask_list``. Occlusion: one SAM call per point, so ``mask_list`` is
    rebuilt with one mask per occluder.

    Returns:
        ``(overlay, mask, mask_list)``; ``(None, None, mask_list)`` when no
        image has been loaded.
    """
    if image is None:
        return None, None, mask_list
    if point_type == "visibility":
        visible_mask = run_sam(image, sam_predictor, points)
        for mask in mask_list:
            visible_mask = cv2.bitwise_or(visible_mask, mask)
        overlaid = apply_mask_overlay(image, visible_mask * 255)
        return overlaid, visible_mask, mask_list

    combined_occlusion_mask = np.zeros_like(image[:, :, 0])
    mask_list = []
    for point in points:
        mask = run_sam(image, sam_predictor, [point])
        mask_list.append(mask)
        combined_occlusion_mask = cv2.bitwise_or(combined_occlusion_mask, mask)
    overlaid = apply_mask_overlay(image, combined_occlusion_mask * 255, color=(0, 255, 0))
    return overlaid, combined_occlusion_mask, mask_list


def delete_mask(visibility_mask_list, occlusion_mask_list, occlusion_points_state, point_type):
    """Undo the last saved mask of the active prompt type.

    For occlusion, the matching point is popped as well so the two lists stay
    aligned.
    """
    if point_type == "visibility":
        if visibility_mask_list:
            visibility_mask_list.pop()
    elif occlusion_mask_list:
        occlusion_mask_list.pop()
        if occlusion_points_state:
            occlusion_points_state.pop()
    return visibility_mask_list, occlusion_mask_list, occlusion_points_state


def check_combined_mask(image, visibility_mask, visibility_mask_list, occlusion_mask_list, scale=0.68):
    """Build the 512x512 model input and its occlusion mask.

    The visible region is the union of the saved visibility masks, or the
    current ``visibility_mask`` when none are saved. Occluded pixels are
    removed from it. The image is rescaled so the visible bounding box's longer
    side is ``scale * 512``, and the box is centred on the canvas.

    Returns:
        ``(final_img, final_occluded_mask)``. In the mask, 0 = occluded,
        127 = visible, 255 = neither (including canvas padding). Returns image
        sized zeros when nothing is visible, and ``(None, None)`` when no image
        is loaded.
    """
    if image is None:
        return None, None
    if visibility_mask is None or visibility_mask.sum() == 0:
        return np.zeros_like(image), np.zeros_like(image[:, :, 0])

    binary_visibility_masks = [(m > 0).astype(np.uint8) for m in visibility_mask_list]
    if binary_visibility_masks:
        combined_mask = np.zeros_like(binary_visibility_masks[0])
        for m in binary_visibility_masks:
            combined_mask = cv2.bitwise_or(combined_mask, m)
    else:
        combined_mask = (visibility_mask > 0).astype(np.uint8)
    if len(binary_visibility_masks) > 1:
        # Close seams between separately segmented visible parts.
        kernel = np.ones((5, 5), np.uint8)
        combined_mask = cv2.dilate(combined_mask, kernel, iterations=1)

    kernel_small = np.ones((3, 3), np.uint8)
    binary_occlusion_masks = [(m > 0).astype(np.uint8) for m in occlusion_mask_list]
    if binary_occlusion_masks:
        occluded_mask = np.zeros_like(binary_occlusion_masks[0])
        for m in binary_occlusion_masks:
            occluded_mask = cv2.bitwise_or(occluded_mask, m)
        # Also treat a thin ring just outside the visible region as occluded.
        dilated = cv2.dilate(combined_mask, kernel_small, iterations=1)
        boundary_mask = dilated - combined_mask
        occluded_mask = cv2.bitwise_or(occluded_mask, boundary_mask)
        occluded_mask = (occluded_mask > 0).astype(np.uint8)
        occluded_mask = cv2.dilate(occluded_mask, kernel_small, iterations=1)
        occluded_mask = (occluded_mask > 0).astype(np.uint8)
    else:
        # No occluders given: everything outside the visible region counts.
        occluded_mask = 1 - combined_mask

    combined_mask[occluded_mask == 1] = 0
    if not combined_mask.any():
        # Occluders cover the whole visible region: there is no bounding box to
        # centre on (and max(w, h) below would be zero).
        return np.zeros_like(image), np.zeros_like(image[:, :, 0])

    occluded_mask = (1 - occluded_mask) * 255
    masked_img = image * combined_mask[:, :, None]
    occluded_mask[combined_mask == 1] = 127

    x, y, w, h = cv2.boundingRect(combined_mask.astype(np.uint8))
    ori_h, ori_w = masked_img.shape[:2]
    target_size = 512
    scale_factor = target_size / max(w, h)
    final_scale = scale_factor * scale
    new_w = int(round(ori_w * final_scale))
    new_h = int(round(ori_h * final_scale))
    resized_occluded_mask = cv2.resize(occluded_mask.astype(np.uint8), (new_w, new_h), interpolation=cv2.INTER_NEAREST)
    resized_img = cv2.resize(masked_img, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)

    final_img = np.zeros((target_size, target_size, 3), dtype=image.dtype)
    final_occluded_mask = np.ones((target_size, target_size), dtype=np.uint8) * 255

    # Offset that moves the scaled bounding-box centre to the canvas centre.
    new_x = int(round(x * final_scale))
    new_y = int(round(y * final_scale))
    new_w_box = int(round(w * final_scale))
    new_h_box = int(round(h * final_scale))
    new_cx = new_x + new_w_box // 2
    new_cy = new_y + new_h_box // 2
    final_cx, final_cy = target_size // 2, target_size // 2
    x_offset = final_cx - new_cx
    y_offset = final_cy - new_cy

    # Clip the paste region to both the canvas and the resized image.
    final_x_start = max(0, x_offset)
    final_y_start = max(0, y_offset)
    final_x_end = min(target_size, x_offset + new_w)
    final_y_end = min(target_size, y_offset + new_h)
    img_x_start = max(0, -x_offset)
    img_y_start = max(0, -y_offset)
    img_x_end = min(new_w, target_size - x_offset)
    img_y_end = min(new_h, target_size - y_offset)

    final_img[final_y_start:final_y_end, final_x_start:final_x_end] = resized_img[img_y_start:img_y_end, img_x_start:img_x_end]
    final_occluded_mask[final_y_start:final_y_end, final_x_start:final_x_end] = resized_occluded_mask[img_y_start:img_y_end, img_x_start:img_x_end]
    return final_img, final_occluded_mask


def get_point(img, point_type, visible_points_state, occlusion_points_state, evt: gr.SelectData):
    """Record a clicked pixel as a prompt of the active type and redraw dots."""
    updated_img = np.array(img)
    x, y = evt.index[0], evt.index[1]
    if point_type == "visibility":
        visible_points_state = add_point(x, y, visible_points_state)
    else:
        occlusion_points_state = add_point(x, y, occlusion_points_state)
    updated_img = see_visible_points(updated_img, visible_points_state)
    updated_img = see_occlusion_points(updated_img, occlusion_points_state)
    return updated_img, visible_points_state, occlusion_points_state


def change_point_type(point_type, visible_points_state, occlusion_points_state):
    """Refresh the points text and dropdown for the newly selected prompt type."""
    points = visible_points_state if point_type == "visibility" else occlusion_points_state
    return update_all_points(points)


def get_seed(randomize_seed: bool, seed: int) -> int:
    """
    Get the random seed.
    """
    return np.random.randint(0, MAX_SEED) if randomize_seed else seed


with gr.Blocks(delete_cache=(600, 600)) as demo:
    gr.Markdown("""
    ## 3D Amodal Reconstruction with [Amodal3R](https://sm0kywu.github.io/Amodal3R/)
    """)
    # NOTE(review): gr.State deep-copies its initial value per session, so each
    # session presumably holds its own copy of the SAM model — confirm memory.
    predictor = gr.State(value=get_sam_predictor())
    visible_points_state = gr.State(value=[])
    occlusion_points_state = gr.State(value=[])
    occlusion_mask = gr.State(value=None)
    occlusion_mask_list = gr.State(value=[])
    original_image = gr.State(value=None)
    visibility_mask = gr.State(value=None)
    visibility_mask_list = gr.State(value=[])
    occluded_mask = gr.State(value=None)
    output_buf = gr.State()

    with gr.Row():
        with gr.Column():
            gr.Markdown("""
            ### Step 1 - Generate Visibility and Occlusion Mask.
            * Please click "Load Example Image" when using the provided example images (bottom).
            * Please wait for a few seconds after uploading the image. Segment Anything is getting ready.
            * **Click to add the point prompts** to indicate the target object (multiple points supported) and occluders (one point for an occluder for better usability).
            * "Add mask", current mask will be saved if the input needs to be added sequentially.
            * The scale of target object can be adjusted for better reconstruction, we suggest 0.4 to 0.7 for most cases.
            """)
            with gr.Row():
                input_image = gr.Image(interactive=True, type='pil', label='Input Occlusion Image', show_label=True, sources="upload", height=300)
                input_with_prompt = gr.Image(type="numpy", label='Input with Prompt', interactive=False, height=300)
            with gr.Row():
                apply_example_btn = gr.Button("Load Example Image")
                message = gr.Markdown("Please wait a few seconds after uploading the image.", label="Message")
            with gr.Row():
                point_type = gr.Radio(["visibility", "occlusion"], label="Point Prompt Type", value="visibility")
            with gr.Row():
                with gr.Column():
                    points_text = gr.Textbox(show_label=False, interactive=False)
                with gr.Column():
                    points_dropdown = gr.Dropdown(show_label=False, choices=[], value=None, interactive=True)
                    delete_button = gr.Button("Delete Selected Point")
            with gr.Row():
                with gr.Column():
                    render_mask = gr.Image(label='Render Mask', interactive=False, height=300)
                    with gr.Row():
                        add_mask = gr.Button("Add Mask")
                        undo_mask = gr.Button("Undo Last Mask")
                with gr.Column():
                    vis_input = gr.Image(label='Visible Input', interactive=False, height=300)
                    with gr.Row():
                        zoom_scale = gr.Slider(0.3, 1.0, label="Target Object Scale", value=0.68, step=0.1)
                    with gr.Row():
                        check_visible_input = gr.Button("Generate Occluded Input")
        with gr.Column():
            gr.Markdown("""
            ### Step 2 - 3D Amodal Reconstruction. (Thanks to [TRELLIS](https://huggingface.co/spaces/JeffreyXiang/TRELLIS) for the 3D rendering component!)
            * Different random seeds can be tried in "Generation Settings", if you think the results are not ideal.
            * The boundary of the segmentation may not be accurate, so here we provide the option to erode the visible area (try 0, 3 or 5).
            * If the reconstructed 3D asset is satisfactory, interactive GLB file can be extracted (may look dull due to the absence of light source) and downloaded.
            """)
            with gr.Row():
                video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
            with gr.Row():
                with gr.Accordion(label="Generation Settings", open=False):
                    with gr.Row():
                        with gr.Column():
                            seed = gr.Slider(0, MAX_SEED, label="Seed", value=1, step=1)
                            randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
                        with gr.Column():
                            erode_kernel_size = gr.Slider(0, 5, label="Erode Kernel Size", value=3, step=1)
                    gr.Markdown("Stage 1: Sparse Structure Generation")
                    with gr.Row():
                        ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                        ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                    gr.Markdown("Stage 2: Structured Latent Generation")
                    with gr.Row():
                        slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
                        slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
            with gr.Row():
                generate_btn = gr.Button("Amodal 3D Reconstruction")
            with gr.Row():
                model_output = gr.Model3D(label="Extracted GLB", pan_speed=0.5, height=300, clear_color=(0.9, 0.9, 0.9, 1))
            with gr.Row():
                with gr.Accordion(label="GLB Extraction Settings", open=False):
                    mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
                    texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
            with gr.Row():
                extract_glb_btn = gr.Button("Extract GLB")
                download_glb = gr.DownloadButton(label="Download GLB", interactive=False)

    with gr.Row():
        examples = gr.Examples(
            examples=[
                f'assets/example_image/{image}'
                for image in os.listdir("assets/example_image")
            ],
            inputs=[input_image],
            fn=lambda x: x,
            outputs=[input_image],
            run_on_click=True,
            examples_per_page=12,
        )

    # Handlers
    demo.load(start_session)
    demo.unload(end_session)

    # Loading a new image: clear visibility masks from the previous image
    # before reset_image clears the point lists (which triggers re-segmentation).
    input_image.upload(
        change_message, [], [message]
    ).then(
        _reset_visibility_state, [], [visibility_mask_list, visibility_mask]
    ).then(
        reset_image,
        [predictor, input_image],
        [predictor, original_image, message, visible_points_state, occlusion_points_state, occlusion_mask_list, input_with_prompt],
    )
    apply_example_btn.click(
        change_message, [], [message]
    ).then(
        _reset_visibility_state, [], [visibility_mask_list, visibility_mask]
    ).then(
        reset_image,
        inputs=[predictor, input_image],
        outputs=[predictor, original_image, message, visible_points_state, occlusion_points_state, occlusion_mask_list, input_with_prompt],
    )

    input_image.select(
        get_point,
        inputs=[input_image, point_type, visible_points_state, occlusion_points_state],
        outputs=[input_with_prompt, visible_points_state, occlusion_points_state],
    )
    point_type.change(
        change_point_type,
        inputs=[point_type, visible_points_state, occlusion_points_state],
        outputs=[points_text, points_dropdown],
    )
    visible_points_state.change(
        update_all_points,
        inputs=[visible_points_state],
        outputs=[points_text, points_dropdown],
    ).then(
        segment_and_overlay,
        inputs=[original_image, visible_points_state, predictor, visibility_mask_list, point_type],
        outputs=[render_mask, visibility_mask, visibility_mask_list],
    ).then(
        check_combined_mask,
        inputs=[original_image, visibility_mask, visibility_mask_list, occlusion_mask_list, zoom_scale],
        outputs=[vis_input, occluded_mask],
    )
    occlusion_points_state.change(
        update_all_points,
        inputs=[occlusion_points_state],
        outputs=[points_text, points_dropdown],
    ).then(
        segment_and_overlay,
        inputs=[original_image, occlusion_points_state, predictor, occlusion_mask_list, point_type],
        outputs=[render_mask, occlusion_mask, occlusion_mask_list],
    ).then(
        check_combined_mask,
        inputs=[original_image, visibility_mask, visibility_mask_list, occlusion_mask_list, zoom_scale],
        outputs=[vis_input, occluded_mask],
    )
    delete_button.click(
        delete_selected,
        inputs=[original_image, visible_points_state, occlusion_points_state, occlusion_mask_list, points_dropdown, point_type],
        outputs=[input_with_prompt, visible_points_state, occlusion_points_state, points_text, points_dropdown],
    )
    add_mask.click(
        add_current_mask,
        inputs=[visibility_mask, visibility_mask_list, point_type],
        outputs=[visibility_mask_list],
    )
    undo_mask.click(
        delete_mask,
        inputs=[visibility_mask_list, occlusion_mask_list, occlusion_points_state, point_type],
        outputs=[visibility_mask_list, occlusion_mask_list, occlusion_points_state],
    )
    check_visible_input.click(
        check_combined_mask,
        inputs=[original_image, visibility_mask, visibility_mask_list, occlusion_mask_list, zoom_scale],
        outputs=[vis_input, occluded_mask],
    )

    # 3D Amodal Reconstruction
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        image_to_3d,
        inputs=[vis_input, occluded_mask, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, erode_kernel_size],
        outputs=[output_buf, video_output],
    )

    extract_glb_btn.click(
        extract_glb,
        inputs=[output_buf, mesh_simplify, texture_size],
        outputs=[model_output, download_glb],
    ).then(
        lambda: gr.Button(interactive=True),
        outputs=[download_glb],
    )

    model_output.clear(
        lambda: gr.Button(interactive=False),
        outputs=[download_glb],
    )


if __name__ == "__main__":
    pipeline = Amodal3RImageTo3DPipeline.from_pretrained("Sm0kyWu/Amodal3R")
    pipeline.cuda()
    try:
        # Best-effort warm-up on a dummy image; a failure here is non-fatal.
        pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
    except Exception as exc:
        print(f"Warm-up preprocessing failed (ignored): {exc}")
    demo.launch()