Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,638 Bytes
daa6779 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
"""
-----------------------------------------------------------------------------
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
-----------------------------------------------------------------------------
"""
import argparse
import glob
import importlib
import os
from datetime import datetime
import cv2
import kiui
import numpy as np
import rembg
import torch
import trimesh
from flow.model import Model
from flow.utils import get_random_color, recenter_foreground
from vae.utils import postprocess_mesh
# PYTHONPATH=. python flow/scripts/infer.py
# Command-line interface. Each spec row: (flag, type, help text, default).
_CLI_SPECS = [
    ("--config", str, "config file path", "flow.configs.big_parts_strict_pvae"),
    ("--ckpt_path", str, "checkpoint path", "pretrained/flow.pt"),
    ("--input", str, "input directory", "assets/images/"),
    ("--limit", int, "limit number of images", -1),
    ("--output_dir", str, "output directory", "output/"),
    ("--grid_res", int, "grid resolution", 384),
    ("--num_steps", int, "number of cfg steps", 30),
    ("--cfg_scale", float, "cfg scale", 7.0),
    ("--num_repeats", int, "number of repeats per image", 1),
    ("--seed", int, "seed", 42),
]
parser = argparse.ArgumentParser()
for _flag, _type, _help, _default in _CLI_SPECS:
    parser.add_argument(_flag, type=_type, help=_help, default=_default)
args = parser.parse_args()

# Axis permutation applied to vertices at export time.
# NOTE(review): presumably remaps the model's axis convention to the
# glTF/GLB y-up convention — confirm against the exporter's expectations.
TRIMESH_GLB_EXPORT = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]).astype(np.float32)

# Single shared rembg session, reused for every image needing background removal.
bg_remover = rembg.new_session()
def preprocess_image(path):
    """Load an image file into a 518x518 float RGB array composited on white.

    Reads the file as RGBA uint8; if the file carries no alpha channel, an
    alpha mask is produced via rembg background removal. The foreground is
    then recentered, resized to 518x518, scaled to [0, 1], and the RGB
    channels are alpha-composited over a white background.

    Returns a float32 numpy array of shape (518, 518, 3) in [0, 1].
    """
    rgba = kiui.read_image(path, mode="uint8", order="RGBA")
    # 3 channels means the file had no alpha -> derive one via bg removal.
    if rgba.shape[-1] == 3:
        rgba = rembg.remove(rgba, session=bg_remover)  # [H, W, 4]
    foreground_mask = rgba[..., -1] > 0
    recentered = recenter_foreground(rgba, foreground_mask, border_ratio=0.1)
    resized = cv2.resize(recentered, (518, 518), interpolation=cv2.INTER_LINEAR)
    scaled = resized.astype(np.float32) / 255.0
    alpha = scaled[..., 3:4]
    # Composite over white: fully transparent pixels become 1.0 (white).
    return scaled[..., :3] * alpha + (1 - alpha)
# Load the flow checkpoint. map_location="cpu" makes loading independent of
# the device the checkpoint was saved on (e.g. a GPU index absent on this
# machine); the weights are moved to CUDA explicitly below.
print(f"Loading checkpoint from {args.ckpt_path}")
ckpt_dict = torch.load(args.ckpt_path, weights_only=True, map_location="cpu")
# Training checkpoints may nest the weights under a "model" key; unwrap it.
if "model" in ckpt_dict:
    ckpt_dict = ckpt_dict["model"]
# Instantiate the model from the config module's make_config() factory,
# in eval mode, on GPU, with bfloat16 weights.
print(f"Instantiating model from {args.config}")
model_config = importlib.import_module(args.config).make_config()
model = Model(model_config).eval().cuda().bfloat16()
# strict=True: every checkpoint key must match the architecture exactly.
print(f"Loading weights from {args.ckpt_path}")
model.load_state_dict(ckpt_dict, strict=True)
# output folder
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
workspace = os.path.join(args.output_dir, "flow_" + args.config.split(".")[-1] + "_" + timestamp)
if not os.path.exists(workspace):
os.makedirs(workspace)
else:
os.system(f"rm {workspace}/*")
print(f"Output directory: {workspace}")
# load test images
if os.path.isdir(args.input):
paths = glob.glob(os.path.join(args.input, "*"))
paths = sorted(paths)
if args.limit > 0:
paths = paths[: args.limit]
else: # single file
paths = [args.input]
# Main inference loop: one conditioning image per input file, args.num_repeats
# latent samples per image, each decoded by the VAE and exported as GLB.
for path in paths:
    name = os.path.splitext(os.path.basename(path))[0]
    print(f"Processing {name}")
    image = preprocess_image(path)
    kiui.write_image(os.path.join(workspace, name + ".jpg"), image)
    # [1, 3, 518, 518] float tensor on GPU for conditioning.
    image = torch.from_numpy(image).permute(2, 0, 1).contiguous().unsqueeze(0).float().cuda()
    # BUGFIX: the conditioning dict gets its own name and is never reused for
    # the VAE query below. Previously the same `data` dict was overwritten
    # with {"latent": ...} in the non-parts branch, so repeats 2..n fed the
    # previous latent to the flow model instead of the image conditioning.
    cond_data = {"cond_images": image}
    for i in range(args.num_repeats):
        # A different seed per repeat yields a different latent sample.
        kiui.seed_everything(args.seed + i)
        with torch.inference_mode():
            results = model(cond_data, num_steps=args.num_steps, cfg_scale=args.cfg_scale)
        latent = results["latent"]
        # kiui.lo(latent)
        # query mesh
        if model.config.use_parts:
            # Dual-volume mode: the latent stacks two part volumes along dim 1
            # (first latent_size tokens, then the rest); decode each half.
            data_part0 = {"latent": latent[:, : model.config.latent_size, :]}
            data_part1 = {"latent": latent[:, model.config.latent_size :, :]}
            with torch.inference_mode():
                results_part0 = model.vae(data_part0, resolution=args.grid_res)
                results_part1 = model.vae(data_part1, resolution=args.grid_res)
            vertices, faces = results_part0["meshes"][0]
            mesh_part0 = trimesh.Trimesh(vertices, faces)
            mesh_part0.vertices = mesh_part0.vertices @ TRIMESH_GLB_EXPORT.T
            mesh_part0 = postprocess_mesh(mesh_part0, 5e4)
            parts = mesh_part0.split(only_watertight=False)
            vertices, faces = results_part1["meshes"][0]
            mesh_part1 = trimesh.Trimesh(vertices, faces)
            mesh_part1.vertices = mesh_part1.vertices @ TRIMESH_GLB_EXPORT.T
            mesh_part1 = postprocess_mesh(mesh_part1, 5e4)
            parts.extend(mesh_part1.split(only_watertight=False))
            # split connected components and assign different colors
            for j, part in enumerate(parts):
                # each component uses a random color
                part.visual.vertex_colors = get_random_color(j, use_float=True)
            mesh = trimesh.Scene(parts)
            # export the whole mesh
            mesh.export(os.path.join(workspace, name + "_" + str(i) + ".glb"))
            # export each part
            for j, part in enumerate(parts):
                part.export(os.path.join(workspace, name + "_" + str(i) + "_part" + str(j) + ".glb"))
            # export dual volumes
            mesh_part0.export(os.path.join(workspace, name + "_" + str(i) + "_vol0.glb"))
            mesh_part1.export(os.path.join(workspace, name + "_" + str(i) + "_vol1.glb"))
        else:
            # Single-volume mode: decode the whole latent in one VAE call.
            vae_data = {"latent": latent}
            with torch.inference_mode():
                results = model.vae(vae_data, resolution=args.grid_res)
            vertices, faces = results["meshes"][0]
            mesh = trimesh.Trimesh(vertices, faces)
            mesh = postprocess_mesh(mesh, 5e4)
            # kiui.lo(mesh.vertices, mesh.faces)
            mesh.vertices = mesh.vertices @ TRIMESH_GLB_EXPORT.T
            mesh.export(os.path.join(workspace, name + "_" + str(i) + ".glb"))
|