import gradio as gr import spaces from gradio_litmodel3d import LitModel3D import os import shutil os.environ['SPCONV_ALGO'] = 'native' from typing import * import torch import numpy as np import imageio from PIL import Image from trellis.pipelines import TrellisImageTo3DPipeline from trellis.utils import render_utils import trimesh import tempfile MAX_SEED = np.iinfo(np.int32).max TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp') os.makedirs(TMP_DIR, exist_ok=True) def preprocess_mesh(mesh_prompt): print("Processing mesh") trimesh_mesh = trimesh.load_mesh(mesh_prompt) trimesh_mesh.export(mesh_prompt+'.glb') return mesh_prompt+'.glb' def preprocess_image(image): if image is None: return None image = pipeline.preprocess_image(image, resolution=1024) return image @spaces.GPU def generate_3d(image, seed=-1, ss_guidance_strength=3, ss_sampling_steps=50, slat_guidance_strength=3, slat_sampling_steps=6,): if image is None: return None, None, None if seed == -1: seed = np.random.randint(0, MAX_SEED) image = pipeline.preprocess_image(image, resolution=1024) normal_image = normal_predictor(image, resolution=768, match_input_resolution=True, data_type='object') outputs = pipeline.run( normal_image, seed=seed, formats=["mesh",], preprocess_image=False, sparse_structure_sampler_params={ "steps": ss_sampling_steps, "cfg_strength": ss_guidance_strength, }, slat_sampler_params={ "steps": slat_sampling_steps, "cfg_strength": slat_guidance_strength, }, ) generated_mesh = outputs['mesh'][0] # Save outputs import datetime output_id = datetime.datetime.now().strftime("%Y%m%d%H%M%S") os.makedirs(os.path.join(TMP_DIR, output_id), exist_ok=True) mesh_path = f"{TMP_DIR}/{output_id}/mesh.glb" render_results = render_utils.render_video(generated_mesh, resolution=1024, ssaa=1, num_frames=8, pitch=0.25, inverse_direction=True) def combine_diagonal(color_np, normal_np): # Convert images to numpy arrays h, w, c = color_np.shape # Create a boolean mask that is True for pixels where x > y (diagonally) mask = np.fromfunction(lambda y, x: x > y, (h, w)) mask = mask.astype(bool) mask = np.stack([mask] * c, axis=-1) # Where mask is True take color, else normal combined_np = np.where(mask, color_np, normal_np) return Image.fromarray(combined_np) preview_images = [combine_diagonal(c, n) for c, n in zip(render_results['color'], render_results['normal'])] # Export mesh trimesh_mesh = generated_mesh.to_trimesh(transform_pose=True) trimesh_mesh.export(mesh_path) return preview_images, normal_image, mesh_path, mesh_path def convert_mesh(mesh_path, export_format): """Download the mesh in the selected format.""" if not mesh_path: return None # Create a temporary file to store the mesh data temp_file = tempfile.NamedTemporaryFile(suffix=f".{export_format}", delete=False) temp_file_path = temp_file.name new_mesh_path = mesh_path.replace(".glb", f".{export_format}") mesh = trimesh.load_mesh(mesh_path) mesh.export(temp_file_path) # Export to the temporary file return temp_file_path # Return the path to the temporary file # Create the Gradio interface with improved layout with gr.Blocks(css="footer {visibility: hidden}") as demo: gr.Markdown( """
V0.1, Introduced By GAP Lab from CUHKSZ and Game-AIGC Team from ByteDance
""" ) with gr.Row(): gr.Markdown(""" """) with gr.Row(): with gr.Column(scale=1): with gr.Tabs(): with gr.Tab("Single Image"): with gr.Row(): image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil") normal_output = gr.Image(label="Normal Bridge", image_mode="RGBA", type="pil") with gr.Tab("Multiple Images"): gr.Markdown("