# Scrape residue (not code): "Hancy's picture" / init / 851751e
from functools import partial
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
class GeoConverter(nn.Module):
    """Project normalized range images onto Cartesian coordinates.

    Every pixel of a range image is decoded to a depth value and projected
    through a fixed yaw/pitch grid (precomputed in :meth:`register_conversion`)
    into ``(x, y, z)`` channels, or only ``(x, y)`` when ``bev_only`` is set.
    When ``curve_length > 1`` the result is additionally averaged over groups
    of ``curve_length`` neighbouring columns.

    Args:
        curve_length: Column window used by :meth:`curve_compress`; values
            ``<= 1`` disable the compression step in :meth:`forward`.
        bev_only: If True, emit only the ``x``/``y`` channels.
        dataset_config: Required configuration providing ``fov``
            (``[up, down]`` angles in degrees), ``depth_scale``,
            ``depth_range`` (``(min, max)`` clamp bounds), ``log_scale`` and
            ``size`` (``[height, width]`` of the range image). Both
            mapping-style (``cfg['fov']``) and attribute-style (``cfg.fov``)
            objects are accepted.

    Raises:
        ValueError: If ``dataset_config`` is not given.
        KeyError: If ``dataset_config`` lacks one of the entries above.
    """

    def __init__(self, curve_length=4, bev_only=False, dataset_config=None):
        super().__init__()
        # The former ``dict()`` default was a shared mutable object and could
        # never work (it has no ``fov``), so fail early with a clear message.
        if dataset_config is None:
            raise ValueError("GeoConverter requires a dataset_config")
        self.curve_length = curve_length
        self.coord_dim = 3 if not bev_only else 2
        self.convert_fn = self.batch_range2bev if bev_only else self.batch_range2xyz

        fov = self._config_value(dataset_config, 'fov')
        self.fov_up = fov[0] / 180.0 * np.pi  # field of view up in rad
        self.fov_down = fov[1] / 180.0 * np.pi  # field of view down in rad
        # Total vertical field of view in rad. abs() makes a downward angle
        # given either as -25 or as 25 produce the same pitch grid.
        self.fov_range = abs(self.fov_down) + abs(self.fov_up)
        self.depth_scale = self._config_value(dataset_config, 'depth_scale')
        self.depth_min, self.depth_max = self._config_value(dataset_config, 'depth_range')
        self.log_scale = self._config_value(dataset_config, 'log_scale')
        self.size = self._config_value(dataset_config, 'size')
        self.register_conversion()

    @staticmethod
    def _config_value(config, key):
        """Read ``key`` from a mapping-style or attribute-style config object."""
        try:
            return config[key]
        except (KeyError, TypeError):
            pass
        try:
            return getattr(config, key)
        except AttributeError:
            raise KeyError(f"dataset_config is missing required entry '{key}'") from None

    def register_conversion(self):
        """Precompute per-pixel yaw/pitch sines and cosines as module buffers.

        Registers ``cos_yaw``, ``sin_yaw``, ``cos_pitch`` and ``sin_pitch``,
        each a float32 tensor of shape ``(height, width)`` taken from
        ``self.size``.
        """
        height, width = self.size[0], self.size[1]
        scan_x, scan_y = np.meshgrid(np.arange(width), np.arange(height))
        scan_x = scan_x.astype(np.float64) / width
        scan_y = scan_y.astype(np.float64) / height
        # Column 0 maps to yaw = -pi; yaw grows towards the right edge.
        yaw = np.pi * (scan_x * 2 - 1)
        # Row 0 maps to pitch = |fov_up|; pitch decreases row by row towards -|fov_down|.
        pitch = (1.0 - scan_y) * self.fov_range - abs(self.fov_down)
        to_torch = partial(torch.tensor, dtype=torch.float32)
        self.register_buffer('cos_yaw', torch.cos(to_torch(yaw)))
        self.register_buffer('sin_yaw', torch.sin(to_torch(yaw)))
        self.register_buffer('cos_pitch', torch.cos(to_torch(pitch)))
        self.register_buffer('sin_pitch', torch.sin(to_torch(pitch)))

    def _decode_depth(self, imgs):
        """Decode range values in [-1, 1] to depth clamped to [depth_min, depth_max]."""
        depth = (imgs * 0.5 + 0.5) * self.depth_scale  # [-1, 1] -> [0, depth_scale]
        if self.log_scale:
            # Inverse of a log2(depth + 1) encoding.
            depth = torch.exp2(depth) - 1
        return depth.clamp(self.depth_min, self.depth_max)

    def batch_range2xyz(self, imgs):
        """Convert range images with values in [-1, 1] to ``(x, y, z)`` channels.

        The decoded depth is broadcast against the ``(height, width)`` angle
        buffers and the three coordinates are concatenated along dim 1.
        Assumes ``imgs`` is ``(B, 1, H, W)`` (output ``(B, 3, H, W)``) -- TODO confirm with callers.
        """
        batch_depth = self._decode_depth(imgs)
        batch_x = self.cos_yaw * self.cos_pitch * batch_depth
        batch_y = -self.sin_yaw * self.cos_pitch * batch_depth
        batch_z = self.sin_pitch * batch_depth
        return torch.cat([batch_x, batch_y, batch_z], dim=1)

    def batch_range2bev(self, imgs):
        """Convert range images with values in [-1, 1] to ``(x, y)`` channels only.

        Same projection as :meth:`batch_range2xyz` without the ``z`` channel.
        """
        batch_depth = self._decode_depth(imgs)
        batch_x = self.cos_yaw * self.cos_pitch * batch_depth
        batch_y = -self.sin_yaw * self.cos_pitch * batch_depth
        return torch.cat([batch_x, batch_y], dim=1)

    def curve_compress(self, batch_coord):
        """Average every ``curve_length`` consecutive columns.

        Uses a ``(1, curve_length)`` average pool with matching stride, so the
        width shrinks by ``curve_length``; trailing columns that do not fill a
        whole window are dropped.
        """
        return F.avg_pool2d(batch_coord, (1, self.curve_length))

    def forward(self, input):
        """Convert range images in [-1, 1] to (optionally compressed) coordinates."""
        # convert_fn already maps [-1, 1] -> [0, 1]; rescaling here as well
        # (as an earlier revision did) confined decoded depth to the upper
        # half of depth_scale, so an input of -1 never decoded to depth 0.
        input_coord = self.convert_fn(input)
        if self.curve_length > 1:
            input_coord = self.curve_compress(input_coord)
        return input_coord