import os
import cv2
import json
import time
import copy
import tqdm
import rembg
import trimesh
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd

from kiui.lpips import LPIPS

import sys
sys.path.append('./')

from sparseags.cam_utils import orbit_camera, OrbitCamera, mat2latlon, find_mask_center_and_translate
from sparseags.render_utils.gs_renderer import CustomCamera
from sparseags.mesh_utils.mesh_renderer import Renderer
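

# Trainer for the mesh refinement stage: optimizes a textured mesh against the
# input views, with zero123 guidance on sampled novel views and optional DINO
# feature supervision.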
class GUI:
    def __init__(self, opt):
        self.opt = opt
        self.gui = opt.gui
        self.W = opt.W
        self.H = opt.H

        self.mode = "image"
        self.seed = 0

        # display buffer in (H, W, 3) layout
        self.buffer_image = np.ones((self.H, self.W, 3), dtype=np.float32)
        self.need_update = True  # update buffer_image

        self.device = torch.device("cuda")
        self.bg_remover = None

        # guidance models, lazily loaded in prepare_train
        self.guidance_sd = None
        self.guidance_zero123 = None
        self.guidance_dino = None

        self.enable_sd = False
        self.enable_zero123 = False
        self.enable_dino = False

        # renderer
        self.renderer = Renderer(opt).to(self.device)

        # input images and masks
        self.input_img = None
        self.input_mask = None
        self.input_img_torch = None
        self.input_mask_torch = None
        self.overlay_input_img = False
        self.overlay_input_img_ratio = 0.5

        # input text prompts
        self.prompt = ""
        self.negative_prompt = ""

        # training state
        self.training = False
        self.optimizer = None
        self.step = 0
        self.train_steps = 1  # steps per rendering loop

        # load input images and cameras
        self.load_input(self.opt.camera_path, self.opt.order_path)

        # override prompts from cmdline
        if self.opt.prompt is not None:
            self.prompt = self.opt.prompt
        if self.opt.negative_prompt is not None:
            self.negative_prompt = self.opt.negative_prompt
    def seed_everything(self):
        try:
            seed = int(self.seed)
        except ValueError:
            seed = np.random.randint(0, 1000000)

        os.environ["PYTHONHASHSEED"] = str(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        # benchmark mode selects cudnn algorithms nondeterministically, so it
        # must stay off when deterministic execution is requested
        torch.backends.cudnn.benchmark = False

        self.last_seed = seed
    def prepare_train(self):

        self.step = 0

        # set up the optimizer over the renderer's parameters
        self.optimizer = torch.optim.Adam(self.renderer.get_params())

        # keep only the views whose flag is set
        cameras = [CustomCamera(v, index=int(k)) for k, v in self.cam_params.items() if v["flag"]]
        cam_centers = [mat2latlon(cam.camera_center) for cam in cameras]
        self.opt.ref_polars = [float(cam[0]) for cam in cam_centers]
        self.opt.ref_azimuths = [float(cam[1]) for cam in cam_centers]
        self.opt.ref_radii = [float(cam[2]) for cam in cam_centers]
        self.cams = [(cam.c2w, cam.perspective, cam.focal_length) for cam in cameras]
        self.cam = copy.deepcopy(cameras[0])
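
        # Pool of principal points for sampled novel views, indexed by
        # elevation sign (2 bins) and azimuth (8 bins of 45 degrees each).
        # Each entry caches a principal point (cx, cy) at a 256-pixel scale,
        # initialized to the image center (128, 128) and refined during
        # training so the rendered object stays centered (see train_step).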
        # float dtype so the refined sub-pixel centers are not truncated
        self.pp_pools = torch.full((2, 8, 2), 128, dtype=torch.float32)

        # use the averaged intrinsics of the input views for novel views
        self.cam.fx = np.array([cam.fx for cam in cameras], dtype=np.float32).mean()
        self.cam.fy = np.array([cam.fy for cam in cameras], dtype=np.float32).mean()
        self.cam.cx = np.array([cam.cx for cam in cameras], dtype=np.float32).mean()
        self.cam.cy = np.array([cam.cy for cam in cameras], dtype=np.float32).mean()

        self.enable_sd = self.opt.lambda_sd > 0 and self.prompt != ""
        self.enable_zero123 = self.opt.lambda_zero123 > 0 and self.input_img is not None
        self.enable_dino = self.opt.lambda_dino > 0

        # lazy-load guidance models
        if self.guidance_sd is None and self.enable_sd:
            if self.opt.mvdream:
                print("[INFO] loading MVDream...")
                from guidance.mvdream_utils import MVDream
                self.guidance_sd = MVDream(self.device)
                print("[INFO] loaded MVDream!")
            else:
                print("[INFO] loading SD...")
                from guidance.sd_utils import StableDiffusion
                self.guidance_sd = StableDiffusion(self.device)
                print("[INFO] loaded SD!")

        if self.guidance_zero123 is None and self.enable_zero123:
            print("[INFO] loading zero123...")
            from sparseags.guidance_utils.zero123_6d_utils import Zero123
            self.guidance_zero123 = Zero123(self.device, model_key='ashawkey/zero123-xl-diffusers')
            print("[INFO] loaded zero123!")

        if self.guidance_dino is None and self.enable_dino:
            print("[INFO] loading dino...")
            from guidance.dino_utils import Dino
            self.guidance_dino = Dino(self.device, n_components=36, model_key="dinov2_vits14")
            self.guidance_dino.fit_pca(self.all_input_images)
            print("[INFO] loaded dino!")

        # move input images and masks to GPU, resized to the reference size
        if self.input_img is not None:
            self.input_img_torch = torch.from_numpy(self.input_img).permute(0, 3, 1, 2).to(self.device)
            self.input_img_torch = F.interpolate(self.input_img_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False)

            self.input_mask_torch = torch.from_numpy(self.input_mask).permute(0, 3, 1, 2).to(self.device)
            self.input_mask_torch = F.interpolate(self.input_mask_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False)
            self.input_img_torch_channel_last = self.input_img_torch.permute(0, 2, 3, 1).contiguous()

        # precompute the guidance embeddings
        with torch.no_grad():

            if self.enable_sd:
                self.guidance_sd.get_text_embeds([self.prompt], [self.negative_prompt])

            if self.enable_zero123:
                self.guidance_zero123.get_img_embeds(self.input_img_torch)

            if self.enable_dino:
                self.guidance_dino.embeddings = self.guidance_dino.get_dino_embeds(self.input_img_torch, upscale=True, reduced=True, learned_up=True)
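
    # Each train_step combines three signals: (1) an MSE reconstruction loss
    # on the known input views, (2) a zero123 refinement loss on randomly
    # sampled novel views, and (3) an optional DINO feature loss.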
    def train_step(self):
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        starter.record()

        for _ in range(self.train_steps):

            self.step += 1
            step_ratio = min(1, self.step / self.opt.iters_refine)

            loss = 0

            # known-view reconstruction loss
            for choice in range(self.num_views):
                # render with a random supersampling factor in [0.125, 2]
                ssaa = min(2.0, max(0.125, 2 * np.random.random()))
                out = self.renderer.render(*self.cams[choice][:2], self.opt.ref_size, self.opt.ref_size, ssaa=ssaa)

                image = out["image"]
                valid_mask = (out["alpha"] > 0).detach()
                loss = loss + F.mse_loss(image * valid_mask, self.input_img_torch_channel_last[choice] * valid_mask)

                if self.enable_dino:
                    feature = out["feature"]
                    loss = loss + F.mse_loss(feature * valid_mask, self.guidance_dino.embeddings[choice] * valid_mask)

            # render novel views for guidance
            render_resolution = 512
            images = []
            vers, hors, radii = [], [], []
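
            # Clamp the sampled elevation to within 30 degrees of the
            # reference elevations, and never beyond [-80, 80] degrees.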
            min_ver = max(-30 + np.array(self.opt.ref_polars).min(), -80)
            max_ver = min(30 + np.array(self.opt.ref_polars).max(), 80)

            for _ in range(self.opt.batch_size):

                # sample a camera relative to the first reference view
                ver = np.random.randint(min_ver, max_ver) - self.opt.ref_polars[0]
                hor = np.random.randint(-180, 180)
                radius = 0

                vers.append(ver)
                hors.append(hor)
                radii.append(radius)

                pose = orbit_camera(self.opt.ref_polars[0] + ver, self.opt.ref_azimuths[0] + hor, np.array(self.opt.ref_radii).mean() + radius)

                ssaa = min(2.0, max(0.125, 2 * np.random.random()))
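
                # Auto-center the novel view: look up the cached principal
                # point for this (elevation sign, azimuth bin), render, measure
                # the offset between the mask center and the image center, and
                # shift (cx, cy) until the object lands within 10 px of center
                # (at a 256-pixel scale), giving up after 5 attempts.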
                idx_ver, idx_hor = int((self.opt.ref_polars[0] + ver) < 0), hor // 45

                flag = 0
                cx, cy = self.pp_pools[idx_ver, idx_hor + 4].tolist()
                cnt = 0

                while not flag:

                    self.cam.cx = cx
                    self.cam.cy = cy

                    if cnt >= 5:
                        print("[ERROR] Principal point adjustment failed to converge after 5 attempts!")
                        break

                    out = self.renderer.render(pose, self.cam.perspective, render_resolution, render_resolution, ssaa=ssaa)

                    image = out["image"]
                    image = image.permute(2, 0, 1).contiguous().unsqueeze(0)
                    mask = out["alpha"] > 0
                    mask = mask.permute(2, 0, 1).contiguous().unsqueeze(0)
                    delta_xy = find_mask_center_and_translate(image.detach(), mask.detach()) / render_resolution * 256

                    if delta_xy[0].abs() > 10 or delta_xy[1].abs() > 10:
                        cx -= delta_xy[0]
                        cy -= delta_xy[1]
                        self.pp_pools[idx_ver, idx_hor + 4] = torch.tensor([cx, cy])
                        cnt += 1
                    else:
                        flag = 1

                images.append(image)

            images = torch.cat(images, dim=0)
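
            # Refine the sampled views with zero123, conditioned on the input
            # view closest in angle to the last sampled camera; the denoise
            # strength grows from 0.8 to 0.95 over the course of training.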
            strength = step_ratio * 0.15 + 0.8
            if self.enable_zero123:
                # (radius, polar, azimuth) of the sampled view and of each reference view
                v1 = torch.stack([torch.tensor([radius]) + self.opt.ref_radii[0], torch.deg2rad(torch.tensor([ver]) + self.opt.ref_polars[0]), torch.deg2rad(torch.tensor([hor]) + self.opt.ref_azimuths[0])], dim=-1)
                v2 = torch.stack([torch.tensor(self.opt.ref_radii), torch.deg2rad(torch.tensor(self.opt.ref_polars)), torch.deg2rad(torch.tensor(self.opt.ref_azimuths))], dim=-1)
                angles = torch.rad2deg(self.guidance_zero123.angle_between(v1, v2)).to(self.device)
                choice = torch.argmin(angles.squeeze()).item()

                cond_RT = {
                    "c2w": self.cams[choice][0],
                    "focal_length": self.cams[choice][-1],
                }
                target_RT = {
                    "c2w": pose,
                    # note: pass a list; np.array(fx, fy) would misread fy as a dtype
                    "focal_length": np.array([self.cam.fx, self.cam.fy]),
                }
                cam_embed = self.guidance_zero123.get_cam_embeddings_6D(target_RT, cond_RT)

                refined_images = self.guidance_zero123.refine(images, cam_embed, strength=strength, idx=choice).float()
                refined_images = F.interpolate(refined_images, (render_resolution, render_resolution), mode="bilinear", align_corners=False)
                loss = loss + self.opt.lambda_zero123 * F.mse_loss(images, refined_images)

            if self.enable_dino:
                loss_dino = self.guidance_dino.train_step(
                    images,
                    out["feature"].permute(2, 0, 1).contiguous(),
                    step_ratio=step_ratio if self.opt.anneal_timestep else None
                )
                loss = loss + self.opt.lambda_dino * loss_dino

            # optimize step
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender)  # elapsed time in milliseconds

        self.need_update = True

    def load_input(self, camera_path, order_path=None):
        # load camera poses (and an optional view-order file) from json
        print(f'[INFO] load data from {camera_path}...')

        if order_path is not None:
            with open(order_path, 'r') as f:
                indices = json.load(f)
        else:
            indices = None

        with open(camera_path, 'r') as f:
            data = json.load(f)

        self.cam_params = {}
        for k, v in data.items():
            if indices is None:
                self.cam_params[k] = data[k]
            else:
                if int(k) in indices or k in indices:
                    self.cam_params[k] = data[k]

            # optionally mark every view as usable
            if self.opt.all_views:
                v['flag'] = 1

        img_paths = [v["filepath"] for k, v in self.cam_params.items() if v["flag"]]
        self.num_views = len(img_paths)
        print(f"[INFO] Number of views: {self.num_views}")

        for filepath in img_paths:
            print(filepath)
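
        # Load each image; if it has no alpha channel, segment the foreground
        # with rembg and composite it onto a white background.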
        images, masks = [], []

        for i in range(len(img_paths)):
            img = cv2.imread(img_paths[i], cv2.IMREAD_UNCHANGED)
            if img.shape[-1] == 3:
                if self.bg_remover is None:
                    self.bg_remover = rembg.new_session()
                img = rembg.remove(img, session=self.bg_remover)

            img = cv2.resize(img, (self.W, self.H), interpolation=cv2.INTER_AREA)
            img = img.astype(np.float32) / 255.0

            input_mask = img[..., 3:]

            # white background
            input_img = img[..., :3] * input_mask + (1 - input_mask)

            # bgr to rgb
            input_img = input_img[..., ::-1].copy()

            images.append(input_img)
            masks.append(input_mask)

        images = np.stack(images, axis=0)
        masks = np.stack(masks, axis=0)
        self.input_img = images[:self.num_views]
        self.input_mask = masks[:self.num_views]
        self.all_input_images = images

    def save_model(self):
        os.makedirs(self.opt.outdir, exist_ok=True)

        path = os.path.join(self.opt.outdir, self.opt.save_path + '.' + self.opt.mesh_format)
        self.renderer.export_mesh(path)

        print(f"[INFO] save model to {path}.")

    def train(self, iters=500):
        if iters > 0:
            self.prepare_train()
            for i in tqdm.trange(iters):
                self.train_step()

        self.save_model()


if __name__ == "__main__":
    import argparse
    from omegaconf import OmegaConf

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="path to the yaml config file")
    args, extras = parser.parse_known_args()

    # override default config from cli
    opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras))

    # if no mesh is given, fall back to the default output of the previous stage
    if opt.mesh is None:
        default_path = os.path.join(opt.outdir, opt.save_path + '_mesh.' + opt.mesh_format)
        if os.path.exists(default_path):
            opt.mesh = default_path
        else:
            raise ValueError(f"Cannot find mesh at {default_path}, must specify --mesh explicitly!")

    gui = GUI(opt)

    gui.train(opt.iters_refine)