import hashlib
import os

import requests
import torch
import torch.nn as nn
from tqdm import tqdm

from . import l1, l2
from ...utils.model_utils import build_model
# Checkpoint registry: model name -> download URL, local filename, and expected MD5.
# The maps are left empty in this file.
URL_MAP = {}
CKPT_MAP = {}
MD5_MAP = {}
# Feature taps for the perceptual loss: each entry is a (layer name, channel count)
# pair taken from the RangeNet encoder/decoder stages.
PERCEPTUAL_TYPE = {
    'rangenet_full': [('enc_0', 32), ('enc_1', 64), ('enc_2', 128), ('enc_3', 256), ('enc_4', 512), ('enc_5', 1024),
                      ('dec_4', 512), ('dec_3', 256), ('dec_2', 128), ('dec_1', 64), ('dec_0', 32)],
    'rangenet_enc': [('enc_0', 32), ('enc_1', 64), ('enc_2', 128), ('enc_3', 256), ('enc_4', 512), ('enc_5', 1024)],
    'rangenet_dec': [('dec_4', 512), ('dec_3', 256), ('dec_2', 128), ('dec_1', 64), ('dec_0', 32)],
    'rangenet_final': [('dec_0', 32)],
}
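# A hypothetical example of registering a custom tap set (not part of the shipped
# configurations) that compares only the two coarsest decoder stages:
#
#     PERCEPTUAL_TYPE['rangenet_dec_coarse'] = [('dec_4', 512), ('dec_3', 256)]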
def download(url, local_path, chunk_size=1024):
    os.makedirs(os.path.dirname(local_path) or '.', exist_ok=True)
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        # Count bytes actually received; the last chunk may be short.
                        pbar.update(len(data))
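# Usage sketch (URL and path are placeholders, not real artifacts):
#
#     download("https://example.com/rangenet.ckpt", "./ckpts/rangenet.ckpt")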
def md5_hash(path):
    with open(path, "rb") as f:
        content = f.read()
    return hashlib.md5(content).hexdigest()
def get_ckpt_path(name, root, check=False):
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    if not os.path.exists(path) or (check and md5_hash(path) != MD5_MAP[name]):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path
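# Usage sketch, assuming the maps above have been populated for a 'rangenet' entry
# (the URL, filename, and digest below are placeholders):
#
#     URL_MAP['rangenet'] = 'https://example.com/rangenet.ckpt'
#     CKPT_MAP['rangenet'] = 'rangenet/model.ckpt'
#     MD5_MAP['rangenet'] = '<expected md5 hex digest>'
#     ckpt = get_ckpt_path('rangenet', root='./pretrained', check=True)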
class NetLinLayer(nn.Module):
    """A single linear layer which does a 1x1 conv."""
    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super().__init__()
        layers = [nn.Dropout()] if use_dropout else []
        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False)]
        self.model = nn.Sequential(*layers)
class PerceptualLoss(nn.Module):
    def __init__(self, ptype, depth_scale, log_scale=True, use_dropout=True, lpips=False, p_loss='l1'):
        super().__init__()
        self.depth_scale = depth_scale
        self.log_scale = log_scale
        self.p_loss = l1 if p_loss == 'l1' else l2
        self.chns = PERCEPTUAL_TYPE[ptype]
        self.return_list = [name for name, _ in self.chns]
        # Per-stage weights, predefined based on the loss magnitude of each stage after
        # a few epochs. Note the list has five entries, matching 'rangenet_dec'.
        self.loss_scale = [5.0, 3.39, 2.29, 1.61, 0.895]
        self.net = build_model('kitti', 'rangenet')
        self.lin_list = nn.ModuleList(
            [NetLinLayer(ch, use_dropout=use_dropout) for _, ch in self.chns]) if lpips else None
        # The backbone and (optional) linear heads are frozen; this module is a fixed loss.
        for param in self.parameters():
            param.requires_grad = False
    @staticmethod
    def normalize_tensor(x, eps=1e-10):
        # Unit-normalize each feature vector along the channel dimension (as in LPIPS).
        norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
        return x / (norm_factor + eps)

    @staticmethod
    def spatial_average(x, keepdim=True):
        return x.mean([2, 3], keepdim=keepdim)
    def preprocess(self, *inputs):
        assert len(inputs) == 2, 'input with both depth images and coord images'
        depth_img, xyz_img = inputs
        # Map the network's [-1, 1] depth back to the standard RangeNet input range.
        depth_img = (depth_img * 0.5 + 0.5) * self.depth_scale
        if self.log_scale:
            # Invert the log2(1 + depth) encoding.
            depth_img = torch.exp2(depth_img) - 1
        img = torch.cat([depth_img, xyz_img], 1)
        return img
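    # Worked example (assuming depth_scale=6.0 and log_scale=True, values chosen for
    # illustration): a normalized depth of 1.0 maps to (1.0 * 0.5 + 0.5) * 6.0 = 6.0,
    # then to 2**6.0 - 1 = 63.0; a normalized depth of -1.0 maps to 2**0 - 1 = 0.0.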
    def forward(self, target, input):
        in0_input, in1_input = self.preprocess(*input), self.preprocess(*target)
        outs0 = self.net(in0_input, return_list=self.return_list)
        outs1 = self.net(in1_input, return_list=self.return_list)
        val_list = []
        for i, (name, _) in enumerate(self.chns):
            # Compare unit-normalized features from the same stage of both inputs.
            feats0 = self.normalize_tensor(outs0[name].to(in0_input.device))
            feats1 = self.normalize_tensor(outs1[name].to(in0_input.device))
            diffs = self.p_loss(feats0, feats1)
            # With LPIPS, a learned 1x1 conv weights the channels; otherwise average them.
            res = self.lin_list[i].model(diffs) if self.lin_list is not None else diffs.mean(1, keepdim=True)
            res = self.spatial_average(res, keepdim=True) * self.loss_scale[i]
            val_list.append(res)
        return sum(val_list)
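# Minimal usage sketch. Assumptions: build_model can locate the pretrained RangeNet
# backbone, depth_scale=6.0 is a plausible placeholder, and the 64x2048 shape follows
# the KITTI range-image convention; the random tensors only illustrate shapes.
#
#     loss_fn = PerceptualLoss('rangenet_dec', depth_scale=6.0).cuda()
#     depth_real, xyz_real = torch.randn(2, 1, 64, 2048).cuda(), torch.randn(2, 3, 64, 2048).cuda()
#     depth_fake, xyz_fake = torch.randn(2, 1, 64, 2048).cuda(), torch.randn(2, 3, 64, 2048).cuda()
#     p_val = loss_fn((depth_real, xyz_real), (depth_fake, xyz_fake))  # shape (2, 1, 1, 1)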