import torch
import torch.nn as nn
import numpy as np

import models.modules as modules
from models.base_model import BaseModule
from models.modules.horizon_net_feature_extractor import HorizonNetFeatureExtractor
from models.modules.patch_feature_extractor import PatchFeatureExtractor
from utils.conversion import uv2depth, get_u, lonlat2depth, get_lon, lonlat2uv
from utils.height import calc_ceil_ratio
from utils.misc import tensor2np

class LGT_Net(BaseModule):
    def __init__(self, ckpt_dir=None, backbone='resnet50', dropout=0.0, output_name='LGT',
                 decoder_name='Transformer', win_size=8, depth=6,
                 ape=None, rpe=None, corner_heat_map=False, rpe_pos=1):
        super().__init__(ckpt_dir)

        self.patch_num = 256
        self.patch_dim = 1024
        self.decoder_name = decoder_name
        self.output_name = output_name
        self.corner_heat_map = corner_heat_map
        self.dropout_d = dropout

        if backbone == 'patch':
            self.feature_extractor = PatchFeatureExtractor(patch_num=self.patch_num, input_shape=[3, 512, 1024])
        else:
            # feature extractor
            self.feature_extractor = HorizonNetFeatureExtractor(backbone)

        if 'Transformer' in self.decoder_name:
            # transformer encoder
            transformer_dim = self.patch_dim
            transformer_layers = depth
            transformer_heads = 8
            transformer_head_dim = transformer_dim // transformer_heads
            transformer_ff_dim = 2048
            rpe = None if rpe == 'None' else rpe
            self.transformer = getattr(modules, decoder_name)(dim=transformer_dim, depth=transformer_layers,
                                                              heads=transformer_heads, dim_head=transformer_head_dim,
                                                              mlp_dim=transformer_ff_dim, win_size=win_size,
                                                              dropout=self.dropout_d, patch_num=self.patch_num,
                                                              ape=ape, rpe=rpe, rpe_pos=rpe_pos)
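            # NOTE (assumption): decoder_name is resolved dynamically via getattr,
            # so any class exported by models.modules whose name contains
            # 'Transformer' and accepts this constructor signature can plug in here.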
        elif self.decoder_name == 'LSTM':
            self.bi_rnn = nn.LSTM(input_size=self.feature_extractor.c_last,
                                  hidden_size=self.patch_dim // 2,
                                  num_layers=2,
                                  dropout=self.dropout_d,
                                  batch_first=False,
                                  bidirectional=True)
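            # bidirectional with hidden_size = patch_dim // 2, so the two
            # directions concatenate to patch_dim (1024) features per column,
            # matching the width produced by the transformer path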
            self.drop_out = nn.Dropout(self.dropout_d)
        else:
            raise NotImplementedError("Only *Transformer and LSTM decoders are supported")

        if self.output_name == 'LGT':
            # omnidirectional-geometry aware output
            self.linear_depth_output = nn.Linear(in_features=self.patch_dim, out_features=1)
            self.linear_ratio = nn.Linear(in_features=self.patch_dim, out_features=1)
            self.linear_ratio_output = nn.Linear(in_features=self.patch_num, out_features=1)
        elif self.output_name == 'LED' or self.output_name == 'Horizon':
            # horizon-depth or latitude output
            self.linear = nn.Linear(in_features=self.patch_dim, out_features=2)
        else:
            raise NotImplementedError(f"Unknown output: {self.output_name}")

        if self.corner_heat_map:
            # corner heat map output
            self.linear_corner_heat_map_output = nn.Linear(in_features=self.patch_dim, out_features=1)

        self.name = f"{self.decoder_name}_{self.output_name}_Net"
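
    # Construction sketch (hypothetical configurations, for illustration only):
    #   net = LGT_Net(decoder_name='Transformer', output_name='LGT')      # depth + ratio heads
    #   net = LGT_Net(decoder_name='LSTM', output_name='LED', dropout=0.1)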
    def lgt_output(self, x):
        """
        :param x: [b, 256(patch_num), 1024(d)]
        :return: {
            'depth': [b, 256(patch_num & d)]
            'ratio': [b, 1(d)]
        }
        """
        depth = self.linear_depth_output(x)  # [b, 256(patch_num), 1(d)]
        depth = depth.view(-1, self.patch_num)  # [b, 256(patch_num & d)]

        # ratio represents room height
        ratio = self.linear_ratio(x)  # [b, 256(patch_num), 1(d)]
        ratio = ratio.view(-1, self.patch_num)  # [b, 256(patch_num & d)]
        ratio = self.linear_ratio_output(ratio)  # [b, 1(d)]

        output = {
            'depth': depth,
            'ratio': ratio
        }
        return output
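
    # Shape walk-through for lgt_output, assuming b=2:
    #   depth: [2, 256, 1024] -> linear_depth_output -> [2, 256, 1] -> view -> [2, 256]
    #   ratio: [2, 256, 1024] -> linear_ratio -> [2, 256] -> linear_ratio_output -> [2, 1]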
    def led_output(self, x):
        """
        :param x: [b, 256(patch_num), 1024(d)]
        :return: {
            'depth': [b, 256(patch_num)]
            'ceil_depth': [b, 256(patch_num)]
            'ratio': [b, 1(d)]
        }
        """
        bon = self.linear(x)  # [b, 256(patch_num), 2(d)]
        bon = bon.permute(0, 2, 1)  # [b, 2(d), 256(patch_num)]
        bon = torch.sigmoid(bon)

        ceil_v = bon[:, 0, :] * -0.5 + 0.5  # [b, 256(patch_num)]
        floor_v = bon[:, 1, :] * 0.5 + 0.5  # [b, 256(patch_num)]
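        # sigmoid keeps bon in (0, 1), so the affine maps above place ceil_v
        # in (0, 0.5) (upper image half) and floor_v in (0.5, 1) (lower half),
        # keeping each boundary on its own side of the horizon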
        u = get_u(w=self.patch_num, is_np=False, b=ceil_v.shape[0]).to(ceil_v.device)
        ceil_boundary = torch.stack((u, ceil_v), dim=-1)  # [b, 256(patch_num), 2]
        floor_boundary = torch.stack((u, floor_v), dim=-1)  # [b, 256(patch_num), 2]
        output = {
            'depth': uv2depth(floor_boundary),  # [b, 256(patch_num)]
            'ceil_depth': uv2depth(ceil_boundary),  # [b, 256(patch_num)]
        }

        if not self.training:
            # [b, 1(d)]
            output['ratio'] = calc_ceil_ratio([tensor2np(ceil_boundary), tensor2np(floor_boundary)],
                                              mode='lsq').reshape(-1, 1)
        return output
    def horizon_output(self, x):
        """
        :param x: [b, 256(patch_num), 1024(d)]
        :return: {
            'boundary': [b, 2(d), 256(patch_num)]
            'depth': [b, 256(patch_num)] (eval only)
            'ratio': [b, 1(d)] (eval only)
        }
        """
        bon = self.linear(x)  # [b, 256(patch_num), 2(d)]
        bon = bon.permute(0, 2, 1)  # [b, 2(d), 256(patch_num)]
        output = {
            'boundary': bon
        }

        if not self.training:
            lon = get_lon(w=self.patch_num, is_np=False, b=bon.shape[0]).to(bon.device)
            floor_lat = torch.clip(bon[:, 0, :], 1e-4, np.pi / 2)
            ceil_lat = torch.clip(bon[:, 1, :], -np.pi / 2, -1e-4)
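            # clipping pins the two boundaries to opposite sides of the
            # horizon: floor latitudes stay strictly positive, ceiling
            # latitudes strictly negative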
            floor_lonlat = torch.stack((lon, floor_lat), dim=-1)  # [b, 256(patch_num), 2]
            ceil_lonlat = torch.stack((lon, ceil_lat), dim=-1)  # [b, 256(patch_num), 2]
            output['depth'] = lonlat2depth(floor_lonlat)
            output['ratio'] = calc_ceil_ratio([tensor2np(lonlat2uv(ceil_lonlat)),
                                               tensor2np(lonlat2uv(floor_lonlat))], mode='mean').reshape(-1, 1)
        return output
    def forward(self, x):
        """
        :param x: [b, 3(d), 512(h), 1024(w)]
        :return: output dict of the selected head, e.g. for 'LGT': {
            'depth': [b, 256(patch_num & d)]
            'ratio': [b, 1(d)]
        }
        """
        # feature extractor
        x = self.feature_extractor(x)  # [b, 1024(d), 256(w)]

        if 'Transformer' in self.decoder_name:
            # transformer decoder
            x = x.permute(0, 2, 1)  # [b, 256(patch_num), 1024(d)]
            x = self.transformer(x)  # [b, 256(patch_num), 1024(d)]
        elif self.decoder_name == 'LSTM':
            # lstm decoder
            x = x.permute(2, 0, 1)  # [256(patch_num), b, 1024(d)]
            self.bi_rnn.flatten_parameters()
            x, _ = self.bi_rnn(x)  # [256(patch_num & seq_len), b, 1024(d)]
            x = x.permute(1, 0, 2)  # [b, 256(patch_num), 1024(d)]
            x = self.drop_out(x)

        output = None
        if self.output_name == 'LGT':
            # lgt output
            output = self.lgt_output(x)
        elif self.output_name == 'LED':
            # led output
            output = self.led_output(x)
        elif self.output_name == 'Horizon':
            # horizon output
            output = self.horizon_output(x)

        if self.corner_heat_map:
            corner_heat_map = self.linear_corner_heat_map_output(x)  # [b, 256(patch_num), 1]
            corner_heat_map = corner_heat_map.view(-1, self.patch_num)
            corner_heat_map = torch.sigmoid(corner_heat_map)
            output['corner_heat_map'] = corner_heat_map
        return output

if __name__ == '__main__':
    from PIL import Image
    from models.other.init_env import init_env

    init_env(0, deterministic=True)

    net = LGT_Net()
    total = sum(p.numel() for p in net.parameters())
    trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print('parameter total: {:,}, trainable: {:,}'.format(total, trainable))

    img = np.array(Image.open("../src/demo.png")).transpose((2, 0, 1))
    input = torch.Tensor([img])  # [1, 3, 512, 1024]
    output = net(input)
    print(output['depth'].shape)  # [1, 256]
    print(output['ratio'].shape)  # [1, 1]
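
    # Minimal sketch of the LED head (hypothetical configuration; 'ratio' is
    # only computed outside training mode, hence eval()):
    led_net = LGT_Net(output_name='LED', decoder_name='LSTM')
    led_net.eval()
    led_out = led_net(input)  # keys: 'depth', 'ceil_depth', 'ratio'
    print(led_out['depth'].shape)  # [1, 256]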