Last commit not found
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
def compute_tri_normal(geometry, tris): | |
tri_1 = tris[:, 0] | |
tri_2 = tris[:, 1] | |
tri_3 = tris[:, 2] | |
vert_1 = torch.index_select(geometry, 1, tri_1) | |
vert_2 = torch.index_select(geometry, 1, tri_2) | |
vert_3 = torch.index_select(geometry, 1, tri_3) | |
nnorm = torch.cross(vert_2-vert_1, vert_3-vert_1, 2) | |
normal = nn.functional.normalize(nnorm) | |
return normal | |
def euler2rot(euler_angle): | |
batch_size = euler_angle.shape[0] | |
theta = euler_angle[:, 0].reshape(-1, 1, 1) | |
phi = euler_angle[:, 1].reshape(-1, 1, 1) | |
psi = euler_angle[:, 2].reshape(-1, 1, 1) | |
one = torch.ones(batch_size, 1, 1).to(euler_angle.device) | |
zero = torch.zeros(batch_size, 1, 1).to(euler_angle.device) | |
rot_x = torch.cat(( | |
torch.cat((one, zero, zero), 1), | |
torch.cat((zero, theta.cos(), theta.sin()), 1), | |
torch.cat((zero, -theta.sin(), theta.cos()), 1), | |
), 2) | |
rot_y = torch.cat(( | |
torch.cat((phi.cos(), zero, -phi.sin()), 1), | |
torch.cat((zero, one, zero), 1), | |
torch.cat((phi.sin(), zero, phi.cos()), 1), | |
), 2) | |
rot_z = torch.cat(( | |
torch.cat((psi.cos(), -psi.sin(), zero), 1), | |
torch.cat((psi.sin(), psi.cos(), zero), 1), | |
torch.cat((zero, zero, one), 1) | |
), 2) | |
return torch.bmm(rot_x, torch.bmm(rot_y, rot_z)) | |
def rot_trans_pts(geometry, rot, trans): | |
rott_geo = torch.bmm(rot, geometry.permute(0, 2, 1)) + trans[:, :, None] | |
return rott_geo.permute(0, 2, 1) | |
def cal_lap_loss(tensor_list, weight_list): | |
lap_kernel = torch.Tensor( | |
(-0.5, 1.0, -0.5)).unsqueeze(0).unsqueeze(0).float().to(tensor_list[0].device) | |
loss_lap = 0 | |
for i in range(len(tensor_list)): | |
in_tensor = tensor_list[i] | |
in_tensor = in_tensor.view(-1, 1, in_tensor.shape[-1]) | |
out_tensor = F.conv1d(in_tensor, lap_kernel) | |
loss_lap += torch.mean(out_tensor**2)*weight_list[i] | |
return loss_lap | |
def proj_pts(rott_geo, focal_length, cxy): | |
cx, cy = cxy[0], cxy[1] | |
X = rott_geo[:, :, 0] | |
Y = rott_geo[:, :, 1] | |
Z = rott_geo[:, :, 2] | |
fxX = focal_length*X | |
fyY = focal_length*Y | |
proj_x = -fxX/Z + cx | |
proj_y = fyY/Z + cy | |
return torch.cat((proj_x[:, :, None], proj_y[:, :, None], Z[:, :, None]), 2) | |
def forward_transform(geometry, euler_angle, trans, focal_length, cxy): | |
rot = euler2rot(euler_angle) | |
rott_geo = rot_trans_pts(geometry, rot, trans) | |
proj_geo = proj_pts(rott_geo, focal_length, cxy) | |
return proj_geo | |
def cal_lan_loss(proj_lan, gt_lan): | |
return torch.mean((proj_lan-gt_lan)**2) |