Spaces:
Running
on
Zero
Running
on
Zero
""" | |
----------------------------------------------------------------------------- | |
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | |
NVIDIA CORPORATION and its licensors retain all intellectual property | |
and proprietary rights in and to this software, related documentation | |
and any modifications thereto. Any use, reproduction, disclosure or | |
distribution of this software and related documentation without an express | |
license agreement from NVIDIA CORPORATION is strictly prohibited. | |
----------------------------------------------------------------------------- | |
""" | |
import argparse | |
import glob | |
import importlib | |
import os | |
from datetime import datetime | |
import cv2 | |
import kiui | |
import numpy as np | |
import rembg | |
import torch | |
import trimesh | |
from flow.model import Model | |
from flow.utils import get_random_color, recenter_foreground | |
from vae.utils import postprocess_mesh | |
# PYTHONPATH=. python flow/scripts/infer.py | |
parser = argparse.ArgumentParser() | |
parser.add_argument( | |
"--config", | |
type=str, | |
help="config file path", | |
default="flow.configs.big_parts_strict_pvae", | |
) | |
parser.add_argument( | |
"--ckpt_path", | |
type=str, | |
help="checkpoint path", | |
default="pretrained/flow.pt", | |
) | |
parser.add_argument("--input", type=str, help="input directory", default="assets/images/") | |
parser.add_argument("--limit", type=int, help="limit number of images", default=-1) | |
parser.add_argument("--output_dir", type=str, help="output directory", default="output/") | |
parser.add_argument("--grid_res", type=int, help="grid resolution", default=384) | |
parser.add_argument("--num_steps", type=int, help="number of cfg steps", default=30) | |
parser.add_argument("--cfg_scale", type=float, help="cfg scale", default=7.0) | |
parser.add_argument("--num_repeats", type=int, help="number of repeats per image", default=1) | |
parser.add_argument("--seed", type=int, help="seed", default=42) | |
args = parser.parse_args() | |
TRIMESH_GLB_EXPORT = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]).astype(np.float32) | |
bg_remover = rembg.new_session() | |
def preprocess_image(path): | |
input_image = kiui.read_image(path, mode="uint8", order="RGBA") | |
# bg removal if there is no alpha channel | |
if input_image.shape[-1] == 3: | |
input_image = rembg.remove(input_image, session=bg_remover) # [H, W, 4] | |
mask = input_image[..., -1] > 0 | |
image = recenter_foreground(input_image, mask, border_ratio=0.1) | |
image = cv2.resize(image, (518, 518), interpolation=cv2.INTER_LINEAR) | |
image = image.astype(np.float32) / 255.0 | |
image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4]) # white background | |
return image | |
print(f"Loading checkpoint from {args.ckpt_path}") | |
ckpt_dict = torch.load(args.ckpt_path, weights_only=True) | |
# delete all keys other than model | |
if "model" in ckpt_dict: | |
ckpt_dict = ckpt_dict["model"] | |
# instantiate model | |
print(f"Instantiating model from {args.config}") | |
model_config = importlib.import_module(args.config).make_config() | |
model = Model(model_config).eval().cuda().bfloat16() | |
# load weight | |
print(f"Loading weights from {args.ckpt_path}") | |
model.load_state_dict(ckpt_dict, strict=True) | |
# output folder | |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") | |
workspace = os.path.join(args.output_dir, "flow_" + args.config.split(".")[-1] + "_" + timestamp) | |
if not os.path.exists(workspace): | |
os.makedirs(workspace) | |
else: | |
os.system(f"rm {workspace}/*") | |
print(f"Output directory: {workspace}") | |
# load test images | |
if os.path.isdir(args.input): | |
paths = glob.glob(os.path.join(args.input, "*")) | |
paths = sorted(paths) | |
if args.limit > 0: | |
paths = paths[: args.limit] | |
else: # single file | |
paths = [args.input] | |
for path in paths: | |
name = os.path.splitext(os.path.basename(path))[0] | |
print(f"Processing {name}") | |
image = preprocess_image(path) | |
kiui.write_image(os.path.join(workspace, name + ".jpg"), image) | |
image = torch.from_numpy(image).permute(2, 0, 1).contiguous().unsqueeze(0).float().cuda() | |
# run model | |
data = {"cond_images": image} | |
for i in range(args.num_repeats): | |
kiui.seed_everything(args.seed + i) | |
with torch.inference_mode(): | |
results = model(data, num_steps=args.num_steps, cfg_scale=args.cfg_scale) | |
latent = results["latent"] | |
# kiui.lo(latent) | |
# query mesh | |
if model.config.use_parts: | |
data_part0 = {"latent": latent[:, : model.config.latent_size, :]} | |
data_part1 = {"latent": latent[:, model.config.latent_size :, :]} | |
with torch.inference_mode(): | |
results_part0 = model.vae(data_part0, resolution=args.grid_res) | |
results_part1 = model.vae(data_part1, resolution=args.grid_res) | |
vertices, faces = results_part0["meshes"][0] | |
mesh_part0 = trimesh.Trimesh(vertices, faces) | |
mesh_part0.vertices = mesh_part0.vertices @ TRIMESH_GLB_EXPORT.T | |
mesh_part0 = postprocess_mesh(mesh_part0, 5e4) | |
parts = mesh_part0.split(only_watertight=False) | |
vertices, faces = results_part1["meshes"][0] | |
mesh_part1 = trimesh.Trimesh(vertices, faces) | |
mesh_part1.vertices = mesh_part1.vertices @ TRIMESH_GLB_EXPORT.T | |
mesh_part1 = postprocess_mesh(mesh_part1, 5e4) | |
parts.extend(mesh_part1.split(only_watertight=False)) | |
# split connected components and assign different colors | |
for j, part in enumerate(parts): | |
# each component uses a random color | |
part.visual.vertex_colors = get_random_color(j, use_float=True) | |
mesh = trimesh.Scene(parts) | |
# export the whole mesh | |
mesh.export(os.path.join(workspace, name + "_" + str(i) + ".glb")) | |
# export each part | |
for j, part in enumerate(parts): | |
part.export(os.path.join(workspace, name + "_" + str(i) + "_part" + str(j) + ".glb")) | |
# export dual volumes | |
mesh_part0.export(os.path.join(workspace, name + "_" + str(i) + "_vol0.glb")) | |
mesh_part1.export(os.path.join(workspace, name + "_" + str(i) + "_vol1.glb")) | |
else: | |
data = {"latent": latent} | |
with torch.inference_mode(): | |
results = model.vae(data, resolution=args.grid_res) | |
vertices, faces = results["meshes"][0] | |
mesh = trimesh.Trimesh(vertices, faces) | |
mesh = postprocess_mesh(mesh, 5e4) | |
# kiui.lo(mesh.vertices, mesh.faces) | |
mesh.vertices = mesh.vertices @ TRIMESH_GLB_EXPORT.T | |
mesh.export(os.path.join(workspace, name + "_" + str(i) + ".glb")) | |