# ... (previous imports remain the same) ...
@spaces.GPU
def image_to_3d(
    image: Image.Image,
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    req: gr.Request,
    progress: gr.Progress = gr.Progress()
) -> Tuple[dict, str, str, str]:
    """
    Convert an image to a 3D model with improved memory management and progress tracking.

    Returns the packed model state, the preview video path, and the full-quality GLB path
    (returned twice).
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)  # make sure the per-session output directory exists
    progress(0, desc="Initializing...")

    # Clear CUDA cache before starting
    torch.cuda.empty_cache()
    try:
        # Generate 3D model with progress updates
        progress(0.1, desc="Running 3D generation pipeline...")
        outputs = pipeline.run(
            image,
            seed=seed,
            formats=["gaussian", "mesh"],
            preprocess_image=False,
            sparse_structure_sampler_params={
                "steps": ss_sampling_steps,
                "cfg_strength": ss_guidance_strength,
            },
            slat_sampler_params={
                "steps": slat_sampling_steps,
                "cfg_strength": slat_guidance_strength,
            },
        )
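        # outputs['gaussian'] and outputs['mesh'] are lists of generated representations;
        # only the first entry of each is used for rendering and export below.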
        progress(0.4, desc="Generating video preview...")

        # Generate video frames in batches to manage memory
        batch_size = 30  # Process 30 frames at a time
        num_frames = 120
        video = []
        video_geo = []

        for i in range(0, num_frames, batch_size):
            end_idx = min(i + batch_size, num_frames)
            batch_frames = render_utils.render_video(
                outputs['gaussian'][0],
                num_frames=end_idx - i,
                start_frame=i
            )['color']
            batch_geo = render_utils.render_video(
                outputs['mesh'][0],
                num_frames=end_idx - i,
                start_frame=i
            )['normal']
            video.extend(batch_frames)
            video_geo.extend(batch_geo)

            # Clear cache after each batch
            torch.cuda.empty_cache()
            progress(0.4 + (0.3 * i / num_frames), desc=f"Rendering frames {i} to {end_idx}...")
        # Combine frames: appearance (color) and geometry (normal) renders side by side
        video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
        # Generate unique ID and save video
        trial_id = str(uuid.uuid4())
        video_path = os.path.join(user_dir, f"{trial_id}.mp4")
        progress(0.7, desc="Saving video...")
        imageio.mimsave(video_path, video, fps=15)

        # Clear video data from memory
        del video
        del video_geo
        torch.cuda.empty_cache()
        # Generate and save full-quality GLB
        progress(0.8, desc="Generating full-quality GLB...")
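        # simplify=0.0 keeps every triangle and texture_size=2048 bakes a high-resolution
        # texture, so this export is the "full quality" version (assuming simplify is the
        # fraction of triangles to remove, as in the reduced export below).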
        glb = postprocessing_utils.to_glb(
            outputs['gaussian'][0],
            outputs['mesh'][0],
            simplify=0.0,
            texture_size=2048,
            verbose=False
        )
        glb_path = os.path.join(user_dir, f"{trial_id}_full.glb")
        progress(0.9, desc="Saving GLB file...")
        glb.export(glb_path)

        # Pack state for reduced version
        progress(0.95, desc="Finalizing...")
        state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)

        # Final cleanup
        torch.cuda.empty_cache()
        progress(1.0, desc="Complete!")
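        # The full-quality GLB path is returned twice, presumably wired to both the 3D
        # viewer and the download component in the UI code elided above.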
        return state, video_path, glb_path, glb_path
    except Exception as e:
        # Clean up on error
        torch.cuda.empty_cache()
        raise gr.Error(f"Processing failed: {str(e)}")
@spaces.GPU
def extract_reduced_glb(
    state: dict,
    mesh_simplify: float,
    texture_size: int,
    req: gr.Request,
    progress: gr.Progress = gr.Progress()
) -> Tuple[str, str]:
    """
    Extract a reduced-quality GLB file with progress tracking.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))

    try:
        progress(0.1, desc="Unpacking model state...")
        gs, mesh, trial_id = unpack_state(state)

        progress(0.3, desc="Generating reduced GLB...")
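        # Higher mesh_simplify (assumed to be the fraction of triangles to remove) and a
        # smaller texture_size trade fidelity for a lighter file than the full-quality export.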
        glb = postprocessing_utils.to_glb(
            gs, mesh,
            simplify=mesh_simplify,
            texture_size=texture_size,
            verbose=False
        )

        progress(0.8, desc="Saving reduced GLB...")
        glb_path = os.path.join(user_dir, f"{trial_id}_reduced.glb")
        glb.export(glb_path)

        progress(0.9, desc="Cleaning up...")
        torch.cuda.empty_cache()

        progress(1.0, desc="Complete!")
        return glb_path, glb_path
    except Exception as e:
        torch.cuda.empty_cache()
        raise gr.Error(f"GLB reduction failed: {str(e)}")
# ... (rest of the UI code remains the same) ...
# Add some memory optimization settings at startup
if __name__ == "__main__":
    # Clear any cached CUDA allocations and enable cuDNN autotuning
    torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = True

    # Initialize pipeline
    pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
    pipeline.cuda()

    try:
        # Preload rembg with minimal memory usage
        test_img = np.zeros((256, 256, 3), dtype=np.uint8)  # Smaller test image
        pipeline.preprocess_image(Image.fromarray(test_img))
        del test_img
        torch.cuda.empty_cache()
    except Exception:
        pass

    demo.launch()