stable-fast-3d / __init__.py
jammmmm's picture
Update to latest inference code
77d8010
import base64
import logging
import os
import sys
from contextlib import nullcontext
import comfy.model_management
import folder_paths
import numpy as np
import torch
import trimesh
from PIL import Image
from trimesh.exchange import gltf
sys.path.append(os.path.dirname(__file__))
from sf3d.system import SF3D
from sf3d.utils import resize_foreground
SF3D_CATEGORY = "StableFast3D"
SF3D_MODEL_NAME = "stabilityai/stable-fast-3d"
class StableFast3DLoader:
CATEGORY = SF3D_CATEGORY
FUNCTION = "load"
RETURN_NAMES = ("sf3d_model",)
RETURN_TYPES = ("SF3D_MODEL",)
@classmethod
def INPUT_TYPES(cls):
return {"required": {}}
def load(self):
device = comfy.model_management.get_torch_device()
model = SF3D.from_pretrained(
SF3D_MODEL_NAME,
config_name="config.yaml",
weight_name="model.safetensors",
)
model.to(device)
model.eval()
return (model,)
class StableFast3DPreview:
CATEGORY = SF3D_CATEGORY
FUNCTION = "preview"
OUTPUT_NODE = True
RETURN_TYPES = ()
@classmethod
def INPUT_TYPES(s):
return {"required": {"mesh": ("MESH",)}}
def preview(self, mesh):
glbs = []
for m in mesh:
scene = trimesh.Scene(m)
glb_data = gltf.export_glb(scene, include_normals=True)
glb_base64 = base64.b64encode(glb_data).decode("utf-8")
glbs.append(glb_base64)
return {"ui": {"glbs": glbs}}
class StableFast3DSampler:
CATEGORY = SF3D_CATEGORY
FUNCTION = "predict"
RETURN_NAMES = ("mesh",)
RETURN_TYPES = ("MESH",)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("SF3D_MODEL",),
"image": ("IMAGE",),
"foreground_ratio": (
"FLOAT",
{"default": 0.85, "min": 0.0, "max": 1.0, "step": 0.01},
),
"texture_resolution": (
"INT",
{"default": 1024, "min": 512, "max": 2048, "step": 256},
),
},
"optional": {
"mask": ("MASK",),
"remesh": (["none", "triangle", "quad"],),
"vertex_count": (
"INT",
{"default": -1, "min": -1, "max": 20000, "step": 1},
),
},
}
def predict(
s,
model,
image,
mask,
foreground_ratio,
texture_resolution,
remesh="none",
vertex_count=-1,
):
if image.shape[0] != 1:
raise ValueError("Only one image can be processed at a time")
pil_image = Image.fromarray(
torch.clamp(torch.round(255.0 * image[0]), 0, 255)
.type(torch.uint8)
.cpu()
.numpy()
)
if mask is not None:
print("Using Mask")
mask_np = np.clip(255.0 * mask[0].detach().cpu().numpy(), 0, 255).astype(
np.uint8
)
mask_pil = Image.fromarray(mask_np, mode="L")
pil_image.putalpha(mask_pil)
else:
if image.shape[3] != 4:
print("No mask or alpha channel detected, Converting to RGBA")
pil_image = pil_image.convert("RGBA")
pil_image = resize_foreground(pil_image, foreground_ratio)
print(remesh)
with torch.no_grad():
with torch.autocast(
device_type="cuda", dtype=torch.bfloat16
) if "cuda" in comfy.model_management.get_torch_device().type else nullcontext():
mesh, glob_dict = model.run_image(
pil_image,
bake_resolution=texture_resolution,
remesh=remesh,
vertex_count=vertex_count,
)
if mesh.vertices.shape[0] == 0:
raise ValueError("No subject detected in the image")
return ([mesh],)
class StableFast3DSave:
CATEGORY = SF3D_CATEGORY
FUNCTION = "save"
OUTPUT_NODE = True
RETURN_TYPES = ()
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"mesh": ("MESH",),
"filename_prefix": ("STRING", {"default": "SF3D"}),
}
}
def __init__(self):
self.type = "output"
def save(self, mesh, filename_prefix):
output_dir = folder_paths.get_output_directory()
glbs = []
for idx, m in enumerate(mesh):
scene = trimesh.Scene(m)
glb_data = gltf.export_glb(scene, include_normals=True)
logging.info(f"Generated GLB model with {len(glb_data)} bytes")
full_output_folder, filename, counter, subfolder, filename_prefix = (
folder_paths.get_save_image_path(filename_prefix, output_dir)
)
filename = filename.replace("%batch_num%", str(idx))
out_path = os.path.join(full_output_folder, f"{filename}_{counter:05}_.glb")
with open(out_path, "wb") as f:
f.write(glb_data)
glbs.append(base64.b64encode(glb_data).decode("utf-8"))
return {"ui": {"glbs": glbs}}
NODE_DISPLAY_NAME_MAPPINGS = {
"StableFast3DLoader": "Stable Fast 3D Loader",
"StableFast3DPreview": "Stable Fast 3D Preview",
"StableFast3DSampler": "Stable Fast 3D Sampler",
"StableFast3DSave": "Stable Fast 3D Save",
}
NODE_CLASS_MAPPINGS = {
"StableFast3DLoader": StableFast3DLoader,
"StableFast3DPreview": StableFast3DPreview,
"StableFast3DSampler": StableFast3DSampler,
"StableFast3DSave": StableFast3DSave,
}
WEB_DIRECTORY = "./comfyui"
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]