|
import math |
|
from dataclasses import dataclass |
|
from typing import List, Optional, Union |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
import trimesh |
|
from PIL import Image |
|
from torch import BoolTensor, FloatTensor |
|
|
|
LIST_TYPE = Union[list, np.ndarray, torch.Tensor] |
|
|
|
|
|
def list_to_pt(
    x: LIST_TYPE, dtype: Optional[torch.dtype] = None, device: Optional[str] = None
) -> torch.Tensor:
    """Convert a list/ndarray/tensor to a torch.Tensor with the given dtype/device.

    Args:
        x: Input values as a Python list, numpy array, or torch tensor.
        dtype: Target dtype; ``None`` keeps the source/default dtype.
        device: Target device; ``None`` keeps the current/default device.

    Returns:
        A tensor holding ``x`` on the requested device with the requested dtype.
    """
    if isinstance(x, (list, np.ndarray)):
        return torch.tensor(x, dtype=dtype, device=device)
    # Bug fix: `device` was previously ignored for tensor inputs, so a tensor
    # passed in on the wrong device stayed there. `.to` is a no-op when both
    # dtype and device already match.
    return x.to(dtype=dtype, device=device)
|
|
|
|
|
def get_c2w(
    elevation_deg: LIST_TYPE,
    distance: LIST_TYPE,
    azimuth_deg: Optional[LIST_TYPE],
    num_views: Optional[int] = 1,
    device: Optional[str] = None,
) -> torch.FloatTensor:
    """Build batched camera-to-world matrices for cameras orbiting the origin.

    Cameras sit on a sphere (z-up world), each looking at the origin. When
    ``azimuth_deg`` is omitted, ``num_views`` azimuths are spread evenly over
    a full turn.

    Returns:
        A ``(num_views, 4, 4)`` float32 tensor of c2w matrices.
    """
    if azimuth_deg is None:
        assert (
            num_views is not None
        ), "num_views must be provided if azimuth_deg is None."
        # Sample num_views+1 points on [0, 360] and drop the duplicate 360°.
        azimuth_deg = torch.linspace(
            0, 360, num_views + 1, dtype=torch.float32, device=device
        )[:-1]
    else:
        num_views = len(azimuth_deg)
    azim = list_to_pt(azimuth_deg, dtype=torch.float32, device=device) * math.pi / 180
    elev = list_to_pt(elevation_deg, dtype=torch.float32, device=device) * math.pi / 180
    dist = list_to_pt(distance, dtype=torch.float32, device=device)
    # Spherical -> Cartesian (z-up): x/y span the ground plane, z is height.
    cos_elev = torch.cos(elev)
    eye = torch.stack(
        [
            dist * cos_elev * torch.cos(azim),
            dist * cos_elev * torch.sin(azim),
            dist * torch.sin(elev),
        ],
        dim=-1,
    )
    target = torch.zeros_like(eye)
    world_up = torch.tensor([0, 0, 1], dtype=torch.float32, device=device)[
        None, :
    ].repeat(num_views, 1)
    # Gram-Schmidt-style look-at basis: forward toward the origin, right and
    # true up re-orthogonalized against it.
    forward = F.normalize(target - eye, dim=-1)
    right = F.normalize(torch.cross(forward, world_up, dim=-1), dim=-1)
    true_up = F.normalize(torch.cross(right, forward, dim=-1), dim=-1)
    # Columns are (right, up, -forward) per the OpenGL camera convention,
    # with the camera position as the translation column.
    pose3x4 = torch.cat(
        [torch.stack([right, true_up, -forward], dim=-1), eye[:, :, None]],
        dim=-1,
    )
    pose = torch.cat([pose3x4, torch.zeros_like(pose3x4[:, :1])], dim=1)
    pose[:, 3, 3] = 1.0
    return pose
|
|
|
|
|
def get_projection_matrix(
    fovy_deg: LIST_TYPE,
    aspect_wh: float = 1.0,
    near: float = 0.1,
    far: float = 100.0,
    device: Optional[str] = None,
) -> torch.FloatTensor:
    """Build batched OpenGL-style perspective projection matrices.

    Args:
        fovy_deg: Per-view vertical field of view in degrees; its length sets
            the batch size.
        aspect_wh: Width-over-height aspect ratio.
        near: Near clip-plane distance.
        far: Far clip-plane distance.
        device: Device for the returned tensor.

    Returns:
        A ``(batch, 4, 4)`` float32 tensor.
    """
    fovy = list_to_pt(fovy_deg, dtype=torch.float32, device=device)
    n_cams = fovy.shape[0]
    half_tan = torch.tan(fovy * math.pi / 180 / 2)
    proj = torch.zeros(n_cams, 4, 4, dtype=torch.float32, device=device)
    proj[:, 0, 0] = 1 / (aspect_wh * half_tan)
    # Negative y focal term flips the vertical axis (image-space y down).
    proj[:, 1, 1] = -1 / half_tan
    proj[:, 2, 2] = -(far + near) / (far - near)
    proj[:, 2, 3] = -2 * far * near / (far - near)
    proj[:, 3, 2] = -1
    return proj
|
|
|
|
|
def get_orthogonal_projection_matrix(
    batch_size: int,
    left: float,
    right: float,
    bottom: float,
    top: float,
    near: float = 0.1,
    far: float = 100.0,
    device: Optional[str] = None,
) -> torch.FloatTensor:
    """Build batched orthographic projection matrices (glOrtho layout, y flipped).

    Args:
        batch_size: Number of identical matrices to produce.
        left, right, bottom, top: View-volume bounds on the image plane.
        near: Near clip-plane distance.
        far: Far clip-plane distance.
        device: Device for the returned tensor.

    Returns:
        A ``(batch_size, 4, 4)`` float32 tensor.
    """
    span_x = right - left
    span_y = top - bottom
    span_z = far - near
    ortho = torch.zeros(batch_size, 4, 4, dtype=torch.float32, device=device)
    # Diagonal scales each axis into NDC; y is negated to flip image space.
    ortho[:, 0, 0] = 2 / span_x
    ortho[:, 1, 1] = -2 / span_y
    ortho[:, 2, 2] = -2 / span_z
    # Last column translates the view-volume center to the origin.
    ortho[:, 0, 3] = -(right + left) / span_x
    ortho[:, 1, 3] = -(top + bottom) / span_y
    ortho[:, 2, 3] = -(far + near) / span_z
    ortho[:, 3, 3] = 1
    return ortho
|
|
|
|
|
@dataclass
class Camera:
    """Container for batched camera tensors.

    Attributes (batch dimension B first):
        c2w: (B, 4, 4) camera-to-world matrices, or None when only w2c is known.
        w2c: (B, 4, 4) world-to-camera (view) matrices. Always present.
        proj_mtx: (B, 4, 4) projection matrices.
        mvp_mtx: (B, 4, 4) combined matrices (proj_mtx @ w2c).
        cam_pos: (B, 3) camera positions in world space, or None.
    """

    c2w: Optional[torch.FloatTensor]
    w2c: torch.FloatTensor
    proj_mtx: torch.FloatTensor
    mvp_mtx: torch.FloatTensor
    cam_pos: Optional[torch.FloatTensor]

    def __getitem__(self, index):
        """Return a new Camera holding the selected view(s).

        An int index keeps the batch dimension (yields a batch of one);
        a slice selects a sub-batch. Other index types are unsupported.
        """
        if isinstance(index, int):
            sl = slice(index, index + 1)
        elif isinstance(index, slice):
            sl = index
        else:
            raise NotImplementedError

        return Camera(
            c2w=self.c2w[sl] if self.c2w is not None else None,
            w2c=self.w2c[sl],
            proj_mtx=self.proj_mtx[sl],
            mvp_mtx=self.mvp_mtx[sl],
            cam_pos=self.cam_pos[sl] if self.cam_pos is not None else None,
        )

    def to(self, device: Optional[str] = None):
        """Move every stored tensor to ``device`` in place."""
        if self.c2w is not None:
            self.c2w = self.c2w.to(device)
        self.w2c = self.w2c.to(device)
        self.proj_mtx = self.proj_mtx.to(device)
        self.mvp_mtx = self.mvp_mtx.to(device)
        if self.cam_pos is not None:
            self.cam_pos = self.cam_pos.to(device)

    def __len__(self):
        # Bug fix: previously read self.c2w, which is Optional and None for
        # cameras built from a user-supplied w2c — len() then crashed with
        # AttributeError. w2c is always present, so use it instead.
        return self.w2c.shape[0]
|
|
|
|
|
def get_camera(
    elevation_deg: Optional[LIST_TYPE] = None,
    distance: Optional[LIST_TYPE] = None,
    fovy_deg: Optional[LIST_TYPE] = None,
    azimuth_deg: Optional[LIST_TYPE] = None,
    num_views: Optional[int] = 1,
    c2w: Optional[torch.FloatTensor] = None,
    w2c: Optional[torch.FloatTensor] = None,
    proj_mtx: Optional[torch.FloatTensor] = None,
    aspect_wh: float = 1.0,
    near: float = 0.1,
    far: float = 100.0,
    device: Optional[str] = None,
):
    """Assemble a perspective :class:`Camera` from angles, matrices, or a mix.

    Precedence: an explicit ``w2c`` wins (pose/position then unknown); else an
    explicit ``c2w`` is inverted; else a c2w is synthesized from the spherical
    parameters. ``proj_mtx`` is likewise used directly when given, otherwise
    built from ``fovy_deg``.
    """
    if w2c is not None:
        # View matrix supplied directly: the pose and eye position are unknown.
        c2w = None
        camera_positions = None
    else:
        if c2w is None:
            c2w = get_c2w(elevation_deg, distance, azimuth_deg, num_views, device)
        camera_positions = c2w[:, :3, 3]
        w2c = torch.linalg.inv(c2w)
    if proj_mtx is None:
        proj_mtx = get_projection_matrix(
            fovy_deg, aspect_wh=aspect_wh, near=near, far=far, device=device
        )
    mvp_mtx = proj_mtx @ w2c
    return Camera(
        c2w=c2w, w2c=w2c, proj_mtx=proj_mtx, mvp_mtx=mvp_mtx, cam_pos=camera_positions
    )
|
|
|
|
|
def get_orthogonal_camera(
    elevation_deg: LIST_TYPE,
    distance: LIST_TYPE,
    left: float,
    right: float,
    bottom: float,
    top: float,
    azimuth_deg: Optional[LIST_TYPE] = None,
    num_views: Optional[int] = 1,
    near: float = 0.1,
    far: float = 100.0,
    device: Optional[str] = None,
):
    """Assemble an orthographic :class:`Camera` orbiting the origin.

    Poses come from the same spherical parameterization as :func:`get_c2w`;
    the projection is an orthographic volume bounded by left/right/bottom/top
    and the near/far planes.
    """
    pose = get_c2w(elevation_deg, distance, azimuth_deg, num_views, device)
    view = torch.linalg.inv(pose)
    ortho = get_orthogonal_projection_matrix(
        batch_size=pose.shape[0],
        left=left,
        right=right,
        bottom=bottom,
        top=top,
        near=near,
        far=far,
        device=device,
    )
    return Camera(
        c2w=pose,
        w2c=view,
        proj_mtx=ortho,
        mvp_mtx=ortho @ view,
        # Translation column of the pose is the eye position in world space.
        cam_pos=pose[:, :3, 3],
    )
|
|