FredWe's picture
Upload 36 files
88bbcbf
raw
history blame
1.97 kB
"""This module contains functions for geometry transform and camera projection"""
import torch
import torch.nn as nn
import numpy as np
def _mat3_from_rows(rows):
    """Assemble a batch of 3x3 matrices from a row-major grid of (B, 1, 1) entries."""
    # Each row's three entries are joined along dim 2 -> (B, 1, 3); the three
    # rows are then joined along dim 1 -> (B, 3, 3).
    return torch.cat([torch.cat(row, 2) for row in rows], 1)


def euler2rot(euler_angle):
    """Return batched rotation matrices ``R_x(theta) @ R_y(phi) @ R_z(psi)``.

    Args:
        euler_angle: tensor indexed as ``euler_angle[:, k]``; columns 0, 1 and 2
            hold theta, phi and psi. Assumed shape (batch, 3) -- TODO confirm
            with callers. The angles go straight into ``cos``/``sin``, so they
            are presumably radians.

    Returns:
        Tensor of shape (batch, 3, 3) with the same dtype and device as
        ``euler_angle``.

    Note:
        ``rot_z`` below is ``[[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]]``, i.e.
        the transpose of the textbook z-rotation (a rotation by ``-psi``),
        whereas ``rot_x`` and ``rot_y`` use the textbook sign layout. This
        convention is kept unchanged because previously fitted angles
        presumably depend on it.
    """
    theta = euler_angle[:, 0].reshape(-1, 1, 1)
    phi = euler_angle[:, 1].reshape(-1, 1, 1)
    psi = euler_angle[:, 2].reshape(-1, 1, 1)
    # Constant entries follow the input's dtype and device rather than a
    # hard-coded float32, so torch.cat never has to mix dtypes.
    one = torch.ones_like(theta)
    zero = torch.zeros_like(theta)
    # Evaluate each trig term once instead of once per matrix entry.
    cos_t, sin_t = theta.cos(), theta.sin()
    cos_p, sin_p = phi.cos(), phi.sin()
    cos_s, sin_s = psi.cos(), psi.sin()

    rot_x = _mat3_from_rows(
        (
            (one, zero, zero),
            (zero, cos_t, -sin_t),
            (zero, sin_t, cos_t),
        )
    )
    rot_y = _mat3_from_rows(
        (
            (cos_p, zero, sin_p),
            (zero, one, zero),
            (-sin_p, zero, cos_p),
        )
    )
    rot_z = _mat3_from_rows(
        (
            (cos_s, sin_s, zero),
            (-sin_s, cos_s, zero),
            (zero, zero, one),
        )
    )
    return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))
def rot_trans_geo(geometry, rot, trans):
    """Rotate every point by a per-batch matrix, then add a per-batch offset.

    Args:
        geometry: points laid out as (batch, n_points, d); the permute and
            ``bmm`` below need a 3-D tensor whose last dim matches ``rot``'s
            column count (expected d == 3).
        rot: batched matrices, expected (batch, 3, 3), applied on the left of
            each point.
        trans: per-batch offsets; viewed as (batch, 3, 1), so it must hold
            three values per batch item.

    Returns:
        Tensor (batch, n_points, 3) containing ``rot @ p + t`` for every point
        ``p``. It is a permuted view of the (batch, 3, n_points) product, so it
        may be non-contiguous.
    """
    offset = trans.view(-1, 3, 1)  # one column vector per batch, broadcast over points
    points_as_cols = geometry.permute(0, 2, 1)  # (B, d, N): one point per column
    moved_cols = torch.bmm(rot, points_as_cols) + offset
    return moved_cols.permute(0, 2, 1)
def euler_trans_geo(geometry, euler, trans):
    """Rotate ``geometry`` by Euler angles, then translate it.

    Convenience wrapper: converts ``euler`` to rotation matrices with
    ``euler2rot`` and passes them, together with ``geometry`` and ``trans``,
    to ``rot_trans_geo``.
    """
    return rot_trans_geo(geometry, euler2rot(euler), trans)
def proj_geo(rott_geo, camera_para):
    """Project camera-space points onto the image plane with a pinhole model.

    Args:
        rott_geo: points assumed shaped (batch, n_points, 3); channels 0, 1, 2
            are read as X, Y, Z.
        camera_para: per-batch camera parameters laid out as
            ``[focal, cx, cy]``. Column 0 is a single focal length shared by
            both image axes; columns 1 and 2 are the principal point.

    Returns:
        Tensor (batch, n_points, 3) holding ``(-f*X/Z + cx, f*Y/Z + cy, Z)``.
        The x term is negated while the y term is not (this module's
        convention). Z is passed through unchanged, and Z == 0 is not guarded.
    """
    # Column slices keep a trailing singleton dim so each parameter broadcasts
    # over the n_points axis.
    focal = camera_para[:, 0:1]  # same focal length for x and y
    cx = camera_para[:, 1:2]
    cy = camera_para[:, 2:3]

    X = rott_geo[:, :, 0]
    Y = rott_geo[:, :, 1]
    Z = rott_geo[:, :, 2]

    u = -(focal * X) / Z + cx
    v = (focal * Y) / Z + cy
    return torch.stack((u, v, Z), dim=2)