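"""Image-to-3D asset generation CLI.

Segments input images, runs the TRELLIS image-to-3D pipeline, bakes textures
via back-projection, and exports OBJ/GLB/URDF assets with quality tags.
"""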
import argparse
import logging
import os
import sys
from glob import glob

import numpy as np
import trimesh
from PIL import Image

from asset3d_gen.data.backproject_v2 import entrypoint as backproject_api
from asset3d_gen.models.delight_model import DelightingModel
from asset3d_gen.models.gs_model import GaussianOperator
from asset3d_gen.models.segment_model import (
    BMGG14Remover,
    RembgRemover,
    SAMPredictor,
    trellis_preprocess,
)
from asset3d_gen.models.sr_model import ImageRealESRGAN
from asset3d_gen.scripts.render_gs import entrypoint as render_gs_api
from asset3d_gen.utils.gpt_clients import GPT_CLIENT
from asset3d_gen.utils.process_media import (
    merge_images_video,
    render_asset3d,
    render_mesh,
    render_video,
)
from asset3d_gen.utils.tags import VERSION
from asset3d_gen.validators.quality_checkers import (
    BaseChecker,
    ImageAestheticChecker,
    ImageSegChecker,
    MeshGeoChecker,
)
from asset3d_gen.validators.urdf_convertor import URDFGenerator
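# Add the repository root (two levels up) to sys.path so the vendored TRELLIS
# package under thirdparty/ can be imported below.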
current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
sys.path.append(os.path.join(current_dir, "../.."))
from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
from thirdparty.TRELLIS.trellis.renderers.mesh_renderer import MeshRenderer
from thirdparty.TRELLIS.trellis.representations import (
    Gaussian,
    MeshExtractResult,
)
from thirdparty.TRELLIS.trellis.representations.gaussian.general_utils import (
    build_scaling_rotation,
    inverse_sigmoid,
    strip_symmetric,
)
from thirdparty.TRELLIS.trellis.utils import postprocessing_utils
from thirdparty.TRELLIS.trellis.utils.render_utils import (
    render_frames,
    yaw_pitch_r_fov_to_extrinsics_intrinsics,
)

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)
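# Environment tweaks: keep compiled torch extensions in the user cache, disable
# Gradio analytics, and use spconv's "native" algorithm (skips startup autotuning).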
os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser( | |
"~/.cache/torch_extensions" | |
) | |
os.environ["GRADIO_ANALYTICS_ENABLED"] = "false" | |
os.environ["SPCONV_ALGO"] = "native" | |
DELIGHT = DelightingModel()
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
RBG_REMOVER = RembgRemover()
RBG14_REMOVER = BMGG14Remover()
SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")

PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
    "jetx/trellis-image-large"
)
PIPELINE.cuda()

SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
AESTHETIC_CHECKER = ImageAestheticChecker()
CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]

TMP_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
)
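# NOTE: TMP_DIR is defined for session outputs but is not referenced in the
# CLI flow below.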


def parse_args():
    parser = argparse.ArgumentParser(description="Image to 3D pipeline args.")
    parser.add_argument(
        "--image_path", type=str, nargs="+", help="Path to the input images."
    )
    parser.add_argument(
        "--image_root", type=str, help="Path to the input images folder."
    )
    parser.add_argument(
        "--output_root",
        type=str,
        required=True,
        help="Root directory for saving outputs.",
    )
    parser.add_argument(
        "--no_mesh", action="store_true", help="Do not output mesh files."
    )
"--height_range", | |
type=str, | |
default=None, | |
help="The hight in meter to restore the mesh real size.", | |
) | |
parser.add_argument( | |
"--mass_range", | |
type=str, | |
default=None, | |
help="The mass in kg to restore the mesh real weight.", | |
) | |
parser.add_argument("--asset_type", type=str, default=None) | |
parser.add_argument("--skip_exists", action="store_true") | |
parser.add_argument("--strict_seg", action="store_true") | |
parser.add_argument("--version", type=str, default=VERSION) | |
args = parser.parse_args() | |
assert ( | |
args.image_path or args.image_root | |
), "Please provide either --image_path or --image_root." | |
if not args.image_path: | |
args.image_path = glob(os.path.join(args.image_root, "*.png")) | |
args.image_path += glob(os.path.join(args.image_root, "*.jpg")) | |
args.image_path += glob(os.path.join(args.image_root, "*.jpeg")) | |
return args | |
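

# Example invocation (the script filename and paths are illustrative);
# --height_range and --mass_range are parsed as "min-max" floats:
#
#   python image_to_3d.py \
#       --image_root ./inputs \
#       --output_root ./outputs \
#       --height_range 0.6-1.2 \
#       --mass_range 0.5-2.0 \
#       --asset_type chair \
#       --skip_exists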


def get_segmented_image(
    image,
    sam_remover,
    rbg_remover,
    seg_checker,
    image_path,
    seg_path,
    mode="loose",
) -> Image.Image:
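    """Segment `image` with `sam_remover`, falling back to `rbg_remover` and
    finally to the raw image (converted to RGBA) whenever the result fails
    `seg_checker`; in "strict" mode a failure after both removers raises."""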
    def _is_valid_seg(img: Image.Image) -> bool:
        return img.mode == "RGBA" and seg_checker([image_path, seg_path])[0]

    seg_image = sam_remover(image, save_path=seg_path)
    if not _is_valid_seg(seg_image):
        logger.warning(
            f"Failed to segment {image_path} by SAM, retry with `rembg`."
        )  # noqa
        seg_image = rbg_remover(image, save_path=seg_path)
        if not _is_valid_seg(seg_image):
            if mode == "strict":
                raise RuntimeError(
                    f"Failed to segment {image_path} by SAM and rembg, abort."
                )
            logger.warning(
                f"Failed to segment {image_path} by rembg, use raw image."
            )  # noqa
            seg_image = image.convert("RGBA")
            seg_image.save(seg_path)

    return seg_image


if __name__ == "__main__":
    args = parse_args()

    for image_path in args.image_path:
        try:
            filename = os.path.basename(image_path).split(".")[0]
            output_root = args.output_root
            if args.image_root is not None:
                output_root = os.path.join(output_root, filename)
            os.makedirs(output_root, exist_ok=True)

            mesh_out = f"{output_root}/{filename}.obj"
            if args.skip_exists and os.path.exists(mesh_out):
                logger.info(
                    f"Skip {image_path}, already processed in {mesh_out}"
                )
                continue

            image = Image.open(image_path)
            image.save(f"{output_root}/{filename}_raw.png")
            # Segmentation: remove the background with rembg unless the input
            # already carries an alpha channel (RGBA).
            seg_path = f"{output_root}/{filename}_cond.png"
            if image.mode != "RGBA":
                seg_image = RBG_REMOVER(image, save_path=seg_path)
                seg_image = trellis_preprocess(seg_image)
            else:
                seg_image = image
            seg_image.save(seg_path)

            # Run the pipeline
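            # preprocess_image=False: the condition image has already been
            # segmented (and cropped via trellis_preprocess when needed) above.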
            try:
                outputs = PIPELINE.run(
                    seg_image,
                    preprocess_image=False,
                    # Optional parameters
                    # seed=1,
                    # sparse_structure_sampler_params={
                    #     "steps": 12,
                    #     "cfg_strength": 7.5,
                    # },
                    # slat_sampler_params={
                    #     "steps": 12,
                    #     "cfg_strength": 3,
                    # },
                )
            except Exception as e:
                logger.error(
                    f"[Pipeline Failed] process {image_path}: {e}, skip."
                )
                continue
            # Render and save color and mesh videos
            gs_model = outputs["gaussian"][0]
            mesh_model = outputs["mesh"][0]
            color_images = render_video(gs_model)["color"]
            normal_images = render_video(mesh_model)["normal"]
            video_path = os.path.join(output_root, "gs_mesh.mp4")
            merge_images_video(color_images, normal_images, video_path)

            if not args.no_mesh:
                # Save the raw Gaussian model
                gs_path = mesh_out.replace(".obj", "_gs.ply")
                gs_model.save_ply(gs_path)
                # Rotate mesh and GS by 90 degrees about the Y-axis to align
                # the coordinate frames.
                rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
                gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
                mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]

                # Additional rotation for GS to align with the mesh.
                gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
                pose = GaussianOperator.trans_to_quatpose(gs_rot)
                aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
                GaussianOperator.resave_ply(
                    in_ply=gs_path,
                    out_ply=aligned_gs_path,
                    instance_pose=pose,
                    device="cpu",
                )

                color_path = os.path.join(output_root, "color.png")
                render_gs_api(aligned_gs_path, color_path)

                mesh = trimesh.Trimesh(
                    vertices=mesh_model.vertices.cpu().numpy(),
                    faces=mesh_model.faces.cpu().numpy(),
                )
                mesh.vertices = mesh.vertices @ np.array(mesh_add_rot)
                mesh.vertices = mesh.vertices @ np.array(rot_matrix)

                mesh_obj_path = os.path.join(output_root, f"{filename}.obj")
                mesh.export(mesh_obj_path)
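                # Back-project the rendered Gaussian color views onto the mesh
                # to bake a 2048x2048 texture (with delighting and mesh fixing
                # enabled), overwriting the .obj in place.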
                mesh = backproject_api(
                    delight_model=DELIGHT,
                    imagesr_model=IMAGESR_MODEL,
                    color_path=color_path,
                    mesh_path=mesh_obj_path,
                    output_path=mesh_obj_path,
                    skip_fix_mesh=False,
                    delight=True,
                    texture_wh=[2048, 2048],
                )

                mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
                mesh.export(mesh_glb_path)
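                # Wrap the textured mesh in a URDF; optional height/mass ranges
                # and asset type from the CLI are attached as asset attributes.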
                urdf_convertor = URDFGenerator(GPT_CLIENT, render_view_num=4)
                asset_attrs = {
                    "version": VERSION,
                    "gs_model": f"{urdf_convertor.output_mesh_dir}/{filename}_gs.ply",
                }
                if args.height_range:
                    min_height, max_height = map(
                        float, args.height_range.split("-")
                    )
                    asset_attrs["min_height"] = min_height
                    asset_attrs["max_height"] = max_height
                if args.mass_range:
                    min_mass, max_mass = map(float, args.mass_range.split("-"))
                    asset_attrs["min_mass"] = min_mass
                    asset_attrs["max_mass"] = max_mass
                if args.asset_type:
                    asset_attrs["category"] = args.asset_type
                if args.version:
                    asset_attrs["version"] = args.version

                urdf_path = urdf_convertor(
                    mesh_path=mesh_obj_path,
                    output_root=f"{output_root}/URDF_{filename}",
                    **asset_attrs,
                )

                # Rescale GS and save to URDF/mesh folder.
                real_height = urdf_convertor.get_attr_from_urdf(
                    urdf_path, attr_name="real_height"
                )
                out_gs = f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}/{filename}_gs.ply"  # noqa
                GaussianOperator.resave_ply(
                    in_ply=aligned_gs_path,
                    out_ply=out_gs,
                    real_height=real_height,
                    device="cpu",
                )

                # Quality check and update .urdf file.
                mesh_out = f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}/{filename}.obj"  # noqa
                trimesh.load(mesh_out).export(mesh_out.replace(".obj", ".glb"))

                # image_paths = render_asset3d(
                #     mesh_path=mesh_out,
                #     output_root=f"{output_root}/URDF_{filename}",
                #     output_subdir="qa_renders",
                #     num_images=8,
                #     elevation=(30, -30),
                #     distance=5.5,
                # )

                image_dir = f"{output_root}/URDF_{filename}/{urdf_convertor.output_render_dir}/image_color"  # noqa
                image_paths = glob(f"{image_dir}/*.png")
                images_list = []
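                # The geometry and aesthetic checkers score the rendered views;
                # the segmentation checker instead compares the raw input
                # against the segmented condition image.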
                for checker in CHECKERS:
                    images = image_paths
                    if isinstance(checker, ImageSegChecker):
                        images = [
                            f"{output_root}/{filename}_raw.png",
                            f"{output_root}/{filename}_cond.png",
                        ]
                    images_list.append(images)

                results = BaseChecker.validate(CHECKERS, images_list)
                urdf_convertor.add_quality_tag(urdf_path, results)

        except Exception as e:
            logger.error(f"Failed to process {image_path}: {e}, skip.")
            continue

    logger.info(f"Processing complete. Outputs saved to {args.output_root}")