import datetime
import os
import tempfile
import time
from functools import lru_cache
from typing import Any

import gradio as gr
import numpy as np
import rembg
import torch
from gradio_litmodel3d import LitModel3D
from PIL import Image

import sf3d.utils as sf3d_utils
from sf3d.system import SF3D
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import FileResponse

app = FastAPI()

rembg_session = rembg.new_session()

# Size and camera parameters of the conditioning image fed to SF3D.
COND_WIDTH = 512
COND_HEIGHT = 512
COND_DISTANCE = 1.6
COND_FOVY_DEG = 40
# RGB (in [0, 1]) that transparent pixels are blended onto in create_batch.
BACKGROUND_COLOR = [0.5, 0.5, 0.5]

# Bucket that generated meshes are uploaded to; also used to build the
# public URL returned by /process-image/.
S3_BUCKET = "framebucket3d"

# Client used by upload_file_to_s3. It must provide
# ``upload_file(file_path, bucket_name, object_name)``.
# NOTE(review): the original code called ``s3_client`` without ever defining
# it, so every upload raised NameError. Assign a real client here (e.g.
# ``boto3.client("s3")``) or from deployment code before serving requests.
s3_client = None

# The conditioning camera never changes, so compute it once at import time.
c2w_cond = sf3d_utils.default_cond_c2w(COND_DISTANCE)
intrinsic, intrinsic_normed_cond = sf3d_utils.create_intrinsic_from_fov_deg(
    COND_FOVY_DEG, COND_HEIGHT, COND_WIDTH
)

model = SF3D.from_pretrained(
    "stabilityai/stable-fast-3d",
    config_name="config.yaml",
    weight_name="model.safetensors",
)
model.eval().cuda()

# Example input paths. Not referenced by the API endpoint below; guarded so a
# missing directory does not crash the server at import time.
EXAMPLES_DIR = "demo_files/examples"
example_files = (
    [os.path.join(EXAMPLES_DIR, f) for f in os.listdir(EXAMPLES_DIR)]
    if os.path.isdir(EXAMPLES_DIR)
    else []
)


def run_model(input_image: Image.Image) -> str:
    """Generate a mesh with SF3D and write it to a temporary .glb file.

    Args:
        input_image: RGBA conditioning image (see ``create_batch``).

    Returns:
        Path of the written .glb file. It is NOT deleted automatically; the
        caller owns the file and should remove it when done.
    """
    start = time.time()

    with torch.no_grad():
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            model_batch = create_batch(input_image)
            model_batch = {k: v.cuda() for k, v in model_batch.items()}
            # NOTE(review): 1024 is presumably the texture/bake resolution;
            # confirm against SF3D.generate_mesh.
            trimesh_mesh, _glob_dict = model.generate_mesh(model_batch, 1024)
            trimesh_mesh = trimesh_mesh[0]

            # mkstemp hands back an open descriptor; close it immediately so
            # it does not leak (the old NamedTemporaryFile handle never got
            # closed) and let the exporter write to the path itself.
            fd, glb_path = tempfile.mkstemp(suffix=".glb")
            os.close(fd)
            trimesh_mesh.export(glb_path, file_type="glb", include_normals=True)

    print("Generation took:", time.time() - start, "s")
    return glb_path


def create_batch(input_image: Image.Image) -> dict[str, Any]:
    """Build the SF3D input batch from a conditioning image.

    The image is resized to COND_WIDTH x COND_HEIGHT and scaled to [0, 1].
    Its last channel is used as the mask, and its first three channels are
    blended over BACKGROUND_COLOR with that mask. The image is therefore
    assumed to be RGBA; an RGB image would use its blue channel as the mask.

    Returns:
        Dict of CPU float tensors with one extra leading dimension added to
        each entry: ``rgb_cond``, ``mask_cond``, ``c2w_cond``,
        ``intrinsic_cond``, ``intrinsic_normed_cond``.
    """
    img_cond = (
        torch.from_numpy(
            np.asarray(input_image.resize((COND_WIDTH, COND_HEIGHT))).astype(
                np.float32
            )
            / 255.0
        )
        .float()
        .clip(0, 1)
    )
    mask_cond = img_cond[:, :, -1:]
    rgb_cond = torch.lerp(
        torch.tensor(BACKGROUND_COLOR)[None, None, :], img_cond[:, :, :3], mask_cond
    )
    batch_elem = {
        "rgb_cond": rgb_cond,
        "mask_cond": mask_cond,
        # Camera tensors get a leading dim here and another below.
        "c2w_cond": c2w_cond.unsqueeze(0),
        "intrinsic_cond": intrinsic.unsqueeze(0),
        "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
    }
    # Add batch dim
    batched = {k: v.unsqueeze(0) for k, v in batch_elem.items()}
    return batched


@lru_cache
def checkerboard(squares: int, size: int, min_value: float = 0.5):
    """Return a ``squares x squares`` checkerboard as a float (N, N, 3) array.

    Cells alternate between ``min_value`` and 1.0, and each cell is
    ``size // squares`` pixels wide, so N == squares * (size // squares).
    The result is cached and shared between callers: do not mutate it in place.
    """
    base = np.zeros((squares, squares)) + min_value
    base[1::2, ::2] = 1
    base[::2, 1::2] = 1

    repeat_mult = size // squares
    return (
        base.repeat(repeat_mult, axis=0)
        .repeat(repeat_mult, axis=1)[:, :, None]
        .repeat(3, axis=-1)
    )


def remove_background(input_image: Image.Image) -> Image.Image:
    """Remove the image background with rembg, reusing the module-level session."""
    return rembg.remove(input_image, session=rembg_session)


def resize_foreground(
    image: Image.Image,
    ratio: float,
) -> Image.Image:
    """Center the visible object and pad it so it fills ``ratio`` of the frame.

    The bounding box of pixels with alpha > 0 is cropped and padded to a
    square. The square is then padded again to ``size / ratio`` and resized to
    COND_WIDTH x COND_HEIGHT.

    Args:
        image: RGBA image.
        ratio: Fraction of the output side occupied by the object. Values
            above 1 produce negative padding, and ``np.pad`` will reject them.

    Raises:
        ValueError: if the image is not 4-channel, or if it has no pixel with
            non-zero alpha.
    """
    image = np.array(image)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if image.ndim != 3 or image.shape[-1] != 4:
        raise ValueError("resize_foreground expects an RGBA image")
    alpha = np.where(image[..., 3] > 0)
    if alpha[0].size == 0:
        raise ValueError("No foreground object detected in the image")
    y1, y2, x1, x2 = (
        alpha[0].min(),
        alpha[0].max(),
        alpha[1].min(),
        alpha[1].max(),
    )
    # crop the foreground; max() is inclusive, so +1 keeps the last row/column
    fg = image[y1 : y2 + 1, x1 : x2 + 1]
    # pad to square
    size = max(fg.shape[0], fg.shape[1])
    ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
    ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
    new_image = np.pad(
        fg,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=((0, 0), (0, 0), (0, 0)),
    )

    # compute padding according to the ratio
    new_size = int(new_image.shape[0] / ratio)
    # pad to size, double side
    ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
    ph1, pw1 = new_size - size - ph0, new_size - size - pw0
    new_image = np.pad(
        new_image,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=((0, 0), (0, 0), (0, 0)),
    )
    new_image = Image.fromarray(new_image, mode="RGBA").resize(
        (COND_WIDTH, COND_HEIGHT)
    )
    return new_image


def square_crop(input_image: Image.Image) -> Image.Image:
    """Center-crop to the largest square and resize to COND_WIDTH x COND_HEIGHT."""
    min_size = min(input_image.size)
    left = (input_image.size[0] - min_size) // 2
    top = (input_image.size[1] - min_size) // 2
    right = (input_image.size[0] + min_size) // 2
    bottom = (input_image.size[1] + min_size) // 2
    return input_image.crop((left, top, right, bottom)).resize(
        (COND_WIDTH, COND_HEIGHT)
    )


def show_mask_img(input_image: Image.Image) -> Image.Image:
    """Composite an RGBA image over a gray/white checkerboard for previewing.

    The checkerboard is fixed at 512x512, so the input must be 512x512 RGBA
    for the arrays to broadcast.
    """
    img_numpy = np.array(input_image)
    alpha = img_numpy[:, :, 3] / 255.0
    chkb = checkerboard(32, 512) * 255
    new_img = img_numpy[..., :3] * alpha[:, :, None] + chkb * (1 - alpha[:, :, None])
    return Image.fromarray(new_img.astype(np.uint8), mode="RGB")


def upload_file_to_s3(file_path, bucket_name, object_name=None, client=None):
    """Upload ``file_path`` to ``bucket_name`` under ``object_name``.

    Args:
        file_path: Local path of the file to upload.
        bucket_name: Destination bucket.
        object_name: Destination key. Defaults to the file's base name.
        client: Object exposing ``upload_file(path, bucket, key)``. Defaults
            to the module-level ``s3_client``.

    Returns:
        True once the upload call has returned. Upload failures propagate as
        whatever exception the client raises.

    Raises:
        RuntimeError: if no client was passed and ``s3_client`` is unset.
    """
    if client is None:
        client = s3_client
    if client is None:
        raise RuntimeError(
            "No S3 client configured: set the module-level s3_client before uploading"
        )
    if object_name is None:
        # A None key is not a valid object name; fall back to the file name.
        object_name = os.path.basename(file_path)
    client.upload_file(file_path, bucket_name, object_name)
    return True


@app.post("/process-image/")
async def process_image(file: UploadFile = File(...), foreground_ratio: float = 0.85):
    """Turn an uploaded image into a .glb mesh, upload it to S3, return its URL.

    Pipeline: background removal -> center square crop -> foreground
    re-framing -> SF3D mesh generation -> upload to S3_BUCKET.

    Returns:
        ``{"glb_path": <public URL of the uploaded .glb>}``

    Raises:
        HTTPException(400): the ratio is out of range, the upload is not a
            readable image, or no foreground object remains after background
            removal.
    """
    # NOTE(review): the heavy steps below are synchronous calls inside an
    # `async def`, so they block the event loop while they run. This also
    # serializes GPU use across requests; revisit if concurrency is needed.

    # resize_foreground pads to size / ratio, so only 0 < ratio <= 1 is valid.
    # NOTE(review): very small ratios still imply huge allocations in
    # resize_foreground; consider a tighter lower bound.
    if not 0 < foreground_ratio <= 1:
        raise HTTPException(
            status_code=400, detail="foreground_ratio must be in the range (0, 1]"
        )

    try:
        with Image.open(file.file) as uploaded:
            input_image = uploaded.convert("RGBA")
    except OSError as err:  # PIL.UnidentifiedImageError subclasses OSError
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image"
        ) from err

    rem_removed = remove_background(input_image)
    sqr_crop = square_crop(rem_removed)
    try:
        fr_res = resize_foreground(sqr_crop, foreground_ratio)
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err)) from err

    glb_file = run_model(fr_res)
    try:
        timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")
        object_name = f"object_{timestamp}.glb"
        upload_file_to_s3(glb_file, S3_BUCKET, object_name)
    finally:
        # The local .glb only has to live until it is uploaded. Cleanup is
        # best-effort: a failure here must not mask the upload outcome.
        try:
            os.remove(glb_file)
        except OSError as err:
            print("Could not delete temporary file", glb_file, err)

    return {"glb_path": f"https://{S3_BUCKET}.s3.amazonaws.com/{object_name}"}