import argparse
import logging
import os
import sys
from glob import glob

import numpy as np
import trimesh
from PIL import Image

from asset3d_gen.data.backproject_v2 import entrypoint as backproject_api
from asset3d_gen.models.delight_model import DelightingModel
from asset3d_gen.models.gs_model import GaussianOperator
from asset3d_gen.models.segment_model import (
    BMGG14Remover,
    RembgRemover,
    SAMPredictor,
    trellis_preprocess,
)
from asset3d_gen.models.sr_model import ImageRealESRGAN
from asset3d_gen.scripts.render_gs import entrypoint as render_gs_api
from asset3d_gen.utils.gpt_clients import GPT_CLIENT
from asset3d_gen.utils.process_media import (
    merge_images_video,
    render_asset3d,
    render_mesh,
    render_video,
)
from asset3d_gen.utils.tags import VERSION
from asset3d_gen.validators.quality_checkers import (
    BaseChecker,
    ImageAestheticChecker,
    ImageSegChecker,
    MeshGeoChecker,
)
from asset3d_gen.validators.urdf_convertor import URDFGenerator

current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
sys.path.append(os.path.join(current_dir, "../.."))
from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
from thirdparty.TRELLIS.trellis.renderers.mesh_renderer import MeshRenderer
from thirdparty.TRELLIS.trellis.representations import (
    Gaussian,
    MeshExtractResult,
)
from thirdparty.TRELLIS.trellis.representations.gaussian.general_utils import (
    build_scaling_rotation,
    inverse_sigmoid,
    strip_symmetric,
)
from thirdparty.TRELLIS.trellis.utils import postprocessing_utils
from thirdparty.TRELLIS.trellis.utils.render_utils import (
    render_frames,
    yaw_pitch_r_fov_to_extrinsics_intrinsics,
)

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)

os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
    "~/.cache/torch_extensions"
)
os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
os.environ["SPCONV_ALGO"] = "native"

DELIGHT = DelightingModel()
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
RBG_REMOVER = RembgRemover()
RBG14_REMOVER = BMGG14Remover()
SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
    "jetx/trellis-image-large"
)
PIPELINE.cuda()
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
AESTHETIC_CHECKER = ImageAestheticChecker()
CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
TMP_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
)


def parse_args():
    parser = argparse.ArgumentParser(description="Image to 3D pipeline args.")
    parser.add_argument(
        "--image_path", type=str, nargs="+", help="Path to the input images."
    )
    parser.add_argument(
        "--image_root", type=str, help="Path to the input images folder."
    )
    parser.add_argument(
        "--output_root",
        type=str,
        required=True,
        help="Root directory for saving outputs.",
    )
    parser.add_argument(
        "--no_mesh", action="store_true", help="Do not output mesh files."
    )
    parser.add_argument(
        "--height_range",
        type=str,
        default=None,
        help="Height range in meters ('min-max') to restore the mesh's real size.",
    )
    parser.add_argument(
        "--mass_range",
        type=str,
        default=None,
        help="Mass range in kg ('min-max') to restore the mesh's real weight.",
    )
    parser.add_argument("--asset_type", type=str, default=None)
    parser.add_argument("--skip_exists", action="store_true")
    parser.add_argument("--strict_seg", action="store_true")
    parser.add_argument("--version", type=str, default=VERSION)
    args = parser.parse_args()

    assert (
        args.image_path or args.image_root
    ), "Please provide either --image_path or --image_root."

    if not args.image_path:
        args.image_path = glob(os.path.join(args.image_root, "*.png"))
        args.image_path += glob(os.path.join(args.image_root, "*.jpg"))
        args.image_path += glob(os.path.join(args.image_root, "*.jpeg"))

    return args


def get_segmented_image(
    image,
    sam_remover,
    rbg_remover,
    seg_checker,
    image_path,
    seg_path,
    mode="loose",
) -> Image.Image:
    def _is_valid_seg(img: Image.Image) -> bool:
        return img.mode == "RGBA" and seg_checker([image_path, seg_path])[0]

    # Try SAM first, then fall back to `rembg`, then to the raw image.
    seg_image = sam_remover(image, save_path=seg_path)
    if not _is_valid_seg(seg_image):
        logger.warning(
            f"Failed to segment {image_path} by SAM, retry with `rembg`."
        )  # noqa
        seg_image = rbg_remover(image, save_path=seg_path)

    if not _is_valid_seg(seg_image):
        if mode == "strict":
            raise RuntimeError(
                f"Failed to segment {image_path} by SAM and rembg, abort."
            )
        logger.warning(
            f"Failed to segment {image_path} by rembg, use raw image."
        )  # noqa
        seg_image = image.convert("RGBA")

    seg_image.save(seg_path)

    return seg_image


if __name__ == "__main__":
    args = parse_args()

    for image_path in args.image_path:
        try:
            filename = os.path.basename(image_path).split(".")[0]
            output_root = args.output_root
            if args.image_root is not None:
                output_root = os.path.join(output_root, filename)
            os.makedirs(output_root, exist_ok=True)

            mesh_out = f"{output_root}/{filename}.obj"
            if args.skip_exists and os.path.exists(mesh_out):
                logger.info(
                    f"Skip {image_path}, already processed in {mesh_out}"
                )
                continue

            image = Image.open(image_path)
            image.save(f"{output_root}/{filename}_raw.png")

            # Segmentation: get a segmented RGBA image (skipped if the input
            # already has an alpha channel).
            seg_path = f"{output_root}/{filename}_cond.png"
            if image.mode != "RGBA":
                seg_image = RBG_REMOVER(image, save_path=seg_path)
                seg_image = trellis_preprocess(seg_image)
            else:
                seg_image = image
            seg_image.save(seg_path)

            # Run the TRELLIS image-to-3D pipeline.
            try:
                outputs = PIPELINE.run(
                    seg_image,
                    preprocess_image=False,
                    # Optional parameters
                    # seed=1,
                    # sparse_structure_sampler_params={
                    #     "steps": 12,
                    #     "cfg_strength": 7.5,
                    # },
                    # slat_sampler_params={
                    #     "steps": 12,
                    #     "cfg_strength": 3,
                    # },
                )
            except Exception as e:
                logger.error(
                    f"[Pipeline Failed] process {image_path}: {e}, skip."
                )
                continue

            # Render and save color (GS) and normal (mesh) videos.
            gs_model = outputs["gaussian"][0]
            mesh_model = outputs["mesh"][0]
            color_images = render_video(gs_model)["color"]
            normal_images = render_video(mesh_model)["normal"]
            video_path = os.path.join(output_root, "gs_mesh.mp4")
            merge_images_video(color_images, normal_images, video_path)

            if not args.no_mesh:
                # Save the raw Gaussian model.
                gs_path = mesh_out.replace(".obj", "_gs.ply")
                gs_model.save_ply(gs_path)

                # Rotate mesh and GS by 90 degrees around Z-axis.
                rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
                gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
                mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]

                # Additional rotation for GS to align with the mesh.
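                # The GS correction below is composed with the base rotation,
                # converted to a quaternion pose, and applied when re-saving
                # the aligned .ply via GaussianOperator.resave_ply.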
                gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
                pose = GaussianOperator.trans_to_quatpose(gs_rot)
                aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
                GaussianOperator.resave_ply(
                    in_ply=gs_path,
                    out_ply=aligned_gs_path,
                    instance_pose=pose,
                    device="cpu",
                )

                # Render the aligned GS; the result is used for texture
                # back-projection.
                color_path = os.path.join(output_root, "color.png")
                render_gs_api(aligned_gs_path, color_path)

                mesh = trimesh.Trimesh(
                    vertices=mesh_model.vertices.cpu().numpy(),
                    faces=mesh_model.faces.cpu().numpy(),
                )
                mesh.vertices = mesh.vertices @ np.array(mesh_add_rot)
                mesh.vertices = mesh.vertices @ np.array(rot_matrix)

                mesh_obj_path = os.path.join(output_root, f"{filename}.obj")
                mesh.export(mesh_obj_path)

                # Back-project the rendered colors onto the mesh texture.
                mesh = backproject_api(
                    delight_model=DELIGHT,
                    imagesr_model=IMAGESR_MODEL,
                    color_path=color_path,
                    mesh_path=mesh_obj_path,
                    output_path=mesh_obj_path,
                    skip_fix_mesh=False,
                    delight=True,
                    texture_wh=[2048, 2048],
                )

                mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
                mesh.export(mesh_glb_path)

                # Convert the textured mesh into a URDF asset.
                urdf_convertor = URDFGenerator(GPT_CLIENT, render_view_num=4)
                asset_attrs = {
                    "version": VERSION,
                    "gs_model": f"{urdf_convertor.output_mesh_dir}/{filename}_gs.ply",
                }
                if args.height_range:
                    min_height, max_height = map(
                        float, args.height_range.split("-")
                    )
                    asset_attrs["min_height"] = min_height
                    asset_attrs["max_height"] = max_height
                if args.mass_range:
                    min_mass, max_mass = map(float, args.mass_range.split("-"))
                    asset_attrs["min_mass"] = min_mass
                    asset_attrs["max_mass"] = max_mass
                if args.asset_type:
                    asset_attrs["category"] = args.asset_type
                if args.version:
                    asset_attrs["version"] = args.version

                urdf_path = urdf_convertor(
                    mesh_path=mesh_obj_path,
                    output_root=f"{output_root}/URDF_{filename}",
                    **asset_attrs,
                )

                # Rescale GS and save to URDF/mesh folder.
                real_height = urdf_convertor.get_attr_from_urdf(
                    urdf_path, attr_name="real_height"
                )
                out_gs = f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}/{filename}_gs.ply"  # noqa
                GaussianOperator.resave_ply(
                    in_ply=aligned_gs_path,
                    out_ply=out_gs,
                    real_height=real_height,
                    device="cpu",
                )

                # Quality check and update .urdf file.
                mesh_out = f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}/{filename}.obj"  # noqa
                trimesh.load(mesh_out).export(mesh_out.replace(".obj", ".glb"))
                # image_paths = render_asset3d(
                #     mesh_path=mesh_out,
                #     output_root=f"{output_root}/URDF_{filename}",
                #     output_subdir="qa_renders",
                #     num_images=8,
                #     elevation=(30, -30),
                #     distance=5.5,
                # )
                image_dir = f"{output_root}/URDF_{filename}/{urdf_convertor.output_render_dir}/image_color"  # noqa
                image_paths = glob(f"{image_dir}/*.png")

                images_list = []
                for checker in CHECKERS:
                    images = image_paths
                    if isinstance(checker, ImageSegChecker):
                        images = [
                            f"{output_root}/{filename}_raw.png",
                            f"{output_root}/{filename}_cond.png",
                        ]
                    images_list.append(images)

                results = BaseChecker.validate(CHECKERS, images_list)
                urdf_convertor.add_quality_tag(urdf_path, results)

        except Exception as e:
            logger.error(f"Failed to process {image_path}: {e}, skip.")
            continue

    logger.info(f"Processing complete. Outputs saved to {args.output_root}")
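# Example invocation (a minimal sketch; the script filename and the paths
# below are illustrative assumptions, not fixed by this module):
#
#   python imageto3d.py \
#       --image_root ./examples/images \
#       --output_root ./outputs \
#       --height_range 0.8-1.2 \
#       --mass_range 0.2-0.5 \
#       --asset_type mug \
#       --skip_exists
#
# --height_range and --mass_range are "min-max" strings (parsed with
# str.split("-") above). When --image_root is given, the outputs for each
# image are written to <output_root>/<image_filename>/.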