import sys
from subprocess import check_call
import tempfile

from os.path import basename, splitext, join
from io import BytesIO

import numpy as np
from scipy.spatial import KDTree
from PIL import Image

import torch
import torch.nn.functional as F
from torchvision.transforms.functional import to_tensor, to_pil_image
from einops import rearrange
import gradio as gr
from huggingface_hub import hf_hub_download

from extern.ZoeDepth.zoedepth.utils.misc import colorize

from gradio_model3dgscamera import Model3DGSCamera

IMAGE_SIZE = 512
NEAR, FAR = 0.01, 100
FOVY = np.deg2rad(55)
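

# Model checkpoints are fetched from the Hugging Face Hub into ./checkpoints
# (download_models() is called at module level further below).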
def download_models():
    models = [
        {
            'repo': 'stabilityai/sd-vae-ft-mse',
            'sub': None,
            'dst': 'checkpoints/sd-vae-ft-mse',
            'files': ['config.json', 'diffusion_pytorch_model.safetensors'],
            'token': None
        },
        {
            'repo': 'lambdalabs/sd-image-variations-diffusers',
            'sub': 'image_encoder',
            'dst': 'checkpoints',
            'files': ['config.json', 'pytorch_model.bin'],
            'token': None
        },
        {
            'repo': 'Sony/genwarp',
            'sub': 'multi1',
            'dst': 'checkpoints',
            'files': [
                'config.json', 'denoising_unet.pth',
                'pose_guider.pth', 'reference_unet.pth'
            ],
            'token': None
        }
    ]

    for model in models:
        for file in model['files']:
            hf_hub_download(
                repo_id=model['repo'],
                subfolder=model['sub'],
                filename=file,
                local_dir=model['dst'],
                token=model['token']
            )
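

# Center-crop the image to a square on its shorter side.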
def crop(img: Image.Image) -> Image.Image:
    W, H = img.size
    if W < H:
        left, right = 0, W
        top, bottom = np.ceil((H - W) / 2.), np.floor((H - W) / 2.) + W
    else:
        left, right = np.ceil((W - H) / 2.), np.floor((W - H) / 2.) + H
        top, bottom = 0, H
    return img.crop((int(left), int(top), int(right), int(bottom)))
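

# Back-project the depth map into a world-space point cloud using the same
# pinhole intrinsics (FOVY) as the 3DGS viewer, then shift the cloud along z
# so the source camera ends up at z = -mean_depth.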
def unproject(depth):
    H, W = depth.shape[2:4]
    mean_depth = depth.mean(dim=(2, 3)).squeeze().item()

    viewport_mtx = get_viewport_matrix(
        IMAGE_SIZE, IMAGE_SIZE,
        batch_size=1
    ).to(depth)

    fovy = torch.ones(1) * FOVY
    proj_mtx = get_projection_matrix(
        fovy=fovy,
        aspect_wh=1.,
        near=NEAR,
        far=FAR
    ).to(depth)

    view_mtx = camera_lookat(
        torch.tensor([[0., 0., 0.]]),
        torch.tensor([[0., 0., 1.]]),
        torch.tensor([[0., -1., 0.]])
    ).to(depth)

    scr_mtx = (viewport_mtx @ proj_mtx).to(depth)

    grid = torch.stack(torch.meshgrid(
        torch.arange(W), torch.arange(H), indexing='xy'), dim=-1
    ).to(depth)[None]

    # Homogeneous screen coordinates (x, y, 0, 1).
    screen = F.pad(grid, (0, 1), 'constant', 0)
    screen = F.pad(screen, (0, 1), 'constant', 1)
    screen_flat = rearrange(screen, 'b h w c -> b (h w) c')

    # Screen space -> eye space.
    eye = screen_flat @ torch.linalg.inv_ex(
        scr_mtx.float()
    )[0].mT.to(depth)
    eye = eye * rearrange(depth, 'b c h w -> b (h w) c')
    eye[..., 3] = 1

    # Eye space -> world space.
    points = eye @ torch.linalg.inv_ex(view_mtx.float())[0].mT.to(depth)
    points = points[0, :, :3]

    # Recenter the cloud at the origin; the camera then sits at -mean_depth.
    points[..., 2] -= mean_depth
    camera_pos = (0, 0, -mean_depth)
    view_mtx = camera_lookat(
        torch.tensor([[0., 0., -mean_depth]]),
        torch.tensor([[0., 0., 0.]]),
        torch.tensor([[0., -1., 0.]])
    ).to(depth)

    return points, camera_pos, view_mtx, proj_mtx
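

# Mean squared distance to each point's three nearest neighbours; used to
# initialise the Gaussian scales, as in 3D Gaussian Splatting.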
def calc_dist2(points: np.ndarray):
    # k=4 because the nearest neighbour of each point is the point itself.
    dists, _ = KDTree(points).query(points, k=4)
    mean_dists = (dists[:, 1:] ** 2).mean(1)
    return mean_dists
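

# Serialise the Gaussians in the .splat layout used by antimatter15-style web
# viewers: per Gaussian, 3x float32 position, 3x float32 scale, 4x uint8 RGBA
# colour, and 4x uint8 quantised rotation quaternion.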
def save_as_splat(
    filepath: str,
    xyz: np.ndarray,
    rgb: np.ndarray
):
    inv_sigmoid = lambda x: np.log(x / (1 - x))
    dist2 = np.clip(calc_dist2(xyz), a_min=0.0000001, a_max=None)
    scales = np.repeat(np.log(np.sqrt(dist2))[..., np.newaxis], 3, axis=1)
    rots = np.zeros((xyz.shape[0], 4))
    rots[:, 0] = 1
    opacities = inv_sigmoid(0.1 * np.ones((xyz.shape[0], 1)))

    # Sort so that large, opaque Gaussians are written first.
    sorted_indices = np.argsort((
        -np.exp(np.sum(scales, axis=-1, keepdims=True))
        / (1 + np.exp(-opacities))
    ).squeeze())

    buffer = BytesIO()
    for idx in sorted_indices:
        position = xyz[idx].astype(np.float32)
        scale = np.exp(scales[idx]).astype(np.float32)
        rot = rots[idx].astype(np.float32)
        color = np.concatenate(
            (rgb[idx], 1 / (1 + np.exp(-opacities[idx]))),
            axis=-1
        )
        buffer.write(position.tobytes())
        buffer.write(scale.tobytes())
        buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes())
        buffer.write(
            ((rot / np.linalg.norm(rot)) * 128 + 128)
            .clip(0, 255)
            .astype(np.uint8)
            .tobytes()
        )

    with open(filepath, 'wb') as f:
        f.write(buffer.getvalue())
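

# Convert the camera pose reported by the viewer (position plus Euler-angle
# rotation) into a view matrix, flipping the Y and Z axes to match the
# renderer's coordinate convention.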
def view_from_rt(position, rotation):
    t = np.array(position)
    euler = np.array(rotation)

    cx = np.cos(euler[0])
    sx = np.sin(euler[0])
    cy = np.cos(euler[1])
    sy = np.sin(euler[1])
    cz = np.cos(euler[2])
    sz = np.sin(euler[2])
    R = np.array([
        cy * cz + sy * sx * sz,
        -cy * sz + sy * sx * cz,
        sy * cx,
        cx * sz,
        cx * cz,
        -sx,
        -sy * cz + cy * sx * sz,
        sy * sz + cy * sx * cz,
        cy * cx
    ])
    view_mtx = np.array([
        [R[0], R[1], R[2], 0],
        [R[3], R[4], R[5], 0],
        [R[6], R[7], R[8], 0],
        [
            -t[0] * R[0] - t[1] * R[3] - t[2] * R[6],
            -t[0] * R[1] - t[1] * R[4] - t[2] * R[7],
            -t[0] * R[2] - t[1] * R[5] - t[2] * R[8],
            1
        ]
    ]).T

    B = np.array([
        [1, 0, 0, 0],
        [0, -1, 0, 0],
        [0, 0, -1, 0],
        [0, 0, 0, 1]
    ])
    return B @ view_mtx
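

# Module-level startup: download checkpoints, load the ZoeDepth depth
# estimator, and install the bundled splatting wheel before the genwarp
# imports below.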
download_models()

# ZoeDepth for monocular depth estimation on the reference image.
mde = torch.hub.load(
    './extern/ZoeDepth',
    'ZoeD_N',
    source='local',
    pretrained=True,
    trust_repo=True
)

import spaces  # Provides the @spaces.GPU decorator used by the callbacks.

check_call([
    sys.executable, '-m', 'pip', 'install',
    'extern/splatting-0.0.1-py3-none-any.whl'
])

from genwarp import GenWarp
from genwarp.ops import (
    camera_lookat, get_projection_matrix, get_viewport_matrix
)

# GenWarp model for novel-view synthesis; moved to CUDA inside the callbacks.
genwarp_cfg = dict(
    pretrained_model_path='checkpoints',
    checkpoint_name='multi1',
    half_precision_weights=True
)
genwarp_nvs = GenWarp(cfg=genwarp_cfg, device='cpu')
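

# Gradio UI. Splat files are written into a temporary directory that lives
# for as long as the app does.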
with tempfile.TemporaryDirectory() as tmpdir:
    with gr.Blocks(
        title='GenWarp Demo',
        css='img {display: inline;}'
    ) as demo:
        # Internal states.
        src_image = gr.State()
        src_depth = gr.State()
        proj_mtx = gr.State()
        src_view_mtx = gr.State()

        gr.Markdown(
            """
# GenWarp: Single Image to Novel Views with Semantic-Preserving Generative Warping

[![Project Site](https://img.shields.io/badge/Project-Web-green)](https://genwarp-nvs.github.io/)
[![Spaces](https://img.shields.io/badge/Spaces-Demo-yellow?logo=huggingface)](https://huggingface.co/spaces/Sony/GenWarp)
[![Github](https://img.shields.io/badge/Github-Repo-orange?logo=github)](https://github.com/sony/genwarp/)
[![Models](https://img.shields.io/badge/Models-checkpoints-blue?logo=huggingface)](https://huggingface.co/Sony/genwarp)
[![arXiv](https://img.shields.io/badge/arXiv-2405.17251-red?logo=arxiv)](https://arxiv.org/abs/2405.17251)

## Introduction
This is an official demo for the paper "[GenWarp: Single Image to Novel Views with Semantic-Preserving Generative Warping](https://genwarp-nvs.github.io/)". GenWarp generates novel-view images from a single input image, conditioned on camera poses. This demo covers basic inference with the model; for details, please refer to the [paper](https://arxiv.org/abs/2405.17251).

## How to Use
1. Upload a reference image to "Reference Input"
    - You can also select an image from "Examples"
2. Move the camera to your desired view in the "Unprojected 3DGS" 3D viewer
3. Hit the "Generate a novel view" button and check the result
"""
        )
        file = gr.File(label='Reference Input', file_types=['image'])
        examples = gr.Examples(
            examples=[
                './assets/pexels-heyho-5998120_19mm.jpg',
                './assets/pexels-itsterrymag-12639296_24mm.jpg'
            ],
            inputs=file
        )
        with gr.Row():
            image_widget = gr.Image(
                label='Reference View', type='filepath',
                interactive=False
            )
            depth_widget = gr.Image(label='Estimated Depth', type='pil')
            viewer = Model3DGSCamera(
                label='Unprojected 3DGS',
                width=IMAGE_SIZE,
                height=IMAGE_SIZE,
                camera_width=IMAGE_SIZE,
                camera_height=IMAGE_SIZE,
                camera_fx=IMAGE_SIZE / (np.tan(FOVY / 2.)) / 2.,
                camera_fy=IMAGE_SIZE / (np.tan(FOVY / 2.)) / 2.,
                camera_near=NEAR,
                camera_far=FAR
            )
        button = gr.Button('Generate a novel view', size='lg', variant='primary')
        with gr.Row():
            warped_widget = gr.Image(
                label='Warped Image', type='pil', interactive=False
            )
            gen_widget = gr.Image(
                label='Generated View', type='pil', interactive=False
            )
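
        # Callbacks (run on GPU via ZeroGPU's @spaces.GPU decorator).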
        @spaces.GPU
        def cb_mde(image_file: str):
            # Preprocess: center-crop, resize to IMAGE_SIZE, move to GPU.
            image = to_tensor(crop(
                Image.open(image_file).convert('RGB')
            ).resize((IMAGE_SIZE, IMAGE_SIZE)))[None].cuda()
            depth = mde.cuda().infer(image)
            depth_image = to_pil_image(colorize(depth[0]))
            return (
                to_pil_image(image[0]), depth_image,
                image.cpu().detach(), depth.cpu().detach()
            )

        @spaces.GPU
        def cb_3d(image, depth, image_file):
            # Unproject to a point cloud and save it as .splat for the viewer.
            xyz, camera_pos, view_mtx, proj_mtx = unproject(depth.cuda())
            rgb = rearrange(image, 'b c h w -> b (h w) c')[0]
            splat_file = join(
                tmpdir, f'{splitext(basename(image_file))[0]}.splat'
            )
            save_as_splat(splat_file, xyz.cpu().detach().numpy(), rgb.cpu().detach().numpy())
            return (splat_file, camera_pos, None), view_mtx.cpu().detach(), proj_mtx.cpu().detach()

        @spaces.GPU
        def cb_generate(viewer, image, depth, src_view_mtx, proj_mtx):
            image = image.cuda()
            depth = depth.cuda()
            src_view_mtx = src_view_mtx.cuda()
            proj_mtx = proj_mtx.cuda()
            # The current viewer camera pose defines the target view.
            src_camera_pos = viewer[1]
            src_camera_rot = viewer[2]
            tar_view_mtx = view_from_rt(src_camera_pos, src_camera_rot)
            tar_view_mtx = torch.from_numpy(tar_view_mtx).to(image)
            rel_view_mtx = (
                tar_view_mtx @ torch.linalg.inv(src_view_mtx.to(image))
            ).to(image)

            # GenWarp inference.
            renders = genwarp_nvs.to('cuda')(
                src_image=image.half(),
                src_depth=depth.half(),
                rel_view_mtx=rel_view_mtx.half(),
                src_proj_mtx=proj_mtx.half(),
                tar_proj_mtx=proj_mtx.half()
            )

            warped = renders['warped']
            synthesized = renders['synthesized']
            warped_pil = to_pil_image(warped[0])
            synthesized_pil = to_pil_image(synthesized[0])

            return warped_pil, synthesized_pil
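
        # Events: estimate depth and unproject on upload; generate on click.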
        file.change(
            fn=cb_mde,
            inputs=file,
            outputs=[image_widget, depth_widget, src_image, src_depth]
        ).then(
            fn=cb_3d,
            inputs=[src_image, src_depth, image_widget],
            outputs=[viewer, src_view_mtx, proj_mtx]
        )
        button.click(
            fn=cb_generate,
            inputs=[viewer, src_image, src_depth, src_view_mtx, proj_mtx],
            outputs=[warped_widget, gen_widget]
        )

    if __name__ == '__main__':
        demo.launch()