from functools import partial import numpy as np import torch from torch import nn import torch.nn.functional as F class GeoConverter(nn.Module): def __init__(self, curve_length=4, bev_only=False, dataset_config=dict()): super().__init__() self.curve_length = curve_length self.coord_dim = 3 if not bev_only else 2 self.convert_fn = self.batch_range2bev if bev_only else self.batch_range2xyz fov = dataset_config.fov self.fov_up = fov[0] / 180.0 * np.pi # field of view up in rad self.fov_down = fov[1] / 180.0 * np.pi # field of view down in rad self.fov_range = abs(self.fov_down) + abs(self.fov_up) # get field of view total in rad self.depth_scale = dataset_config.depth_scale self.depth_min, self.depth_max = dataset_config.depth_range self.log_scale = dataset_config.log_scale self.size = dataset_config['size'] self.register_conversion() def register_conversion(self): scan_x, scan_y = np.meshgrid(np.arange(self.size[1]), np.arange(self.size[0])) scan_x = scan_x.astype(np.float64) / self.size[1] scan_y = scan_y.astype(np.float64) / self.size[0] yaw = (np.pi * (scan_x * 2 - 1)) pitch = ((1.0 - scan_y) * self.fov_range - abs(self.fov_down)) to_torch = partial(torch.tensor, dtype=torch.float32) self.register_buffer('cos_yaw', torch.cos(to_torch(yaw))) self.register_buffer('sin_yaw', torch.sin(to_torch(yaw))) self.register_buffer('cos_pitch', torch.cos(to_torch(pitch))) self.register_buffer('sin_pitch', torch.sin(to_torch(pitch))) def batch_range2xyz(self, imgs): batch_depth = (imgs * 0.5 + 0.5) * self.depth_scale if self.log_scale: batch_depth = torch.exp2(batch_depth) - 1 batch_depth = batch_depth.clamp(self.depth_min, self.depth_max) batch_x = self.cos_yaw * self.cos_pitch * batch_depth batch_y = -self.sin_yaw * self.cos_pitch * batch_depth batch_z = self.sin_pitch * batch_depth batch_xyz = torch.cat([batch_x, batch_y, batch_z], dim=1) return batch_xyz def batch_range2bev(self, imgs): batch_depth = (imgs * 0.5 + 0.5) * self.depth_scale if self.log_scale: batch_depth = torch.exp2(batch_depth) - 1 batch_depth = batch_depth.clamp(self.depth_min, self.depth_max) batch_x = self.cos_yaw * self.cos_pitch * batch_depth batch_y = -self.sin_yaw * self.cos_pitch * batch_depth batch_bev = torch.cat([batch_x, batch_y], dim=1) return batch_bev def curve_compress(self, batch_coord): compressed_batch_coord = F.avg_pool2d(batch_coord, (1, self.curve_length)) return compressed_batch_coord def forward(self, input): input = input / 2. + .5 # [-1, 1] -> [0, 1] input_coord = self.convert_fn(input) if self.curve_length > 1: input_coord = self.curve_compress(input_coord) return input_coord