"""Inference: predict a depth map from a monocular spike stream with a
pretrained S2DepthTransformerUNetConv model."""

import argparse
import json
import logging
import os

import cv2
import matplotlib as mpl
import matplotlib.cm as cm
import numpy as np
import torch

from SpikeT.model.S2DepthNet import S2DepthTransformerUNetConv  # referenced via eval(config['arch'])
from SpikeT.utils.data_augmentation import CenterCrop

logging.basicConfig(level=logging.INFO, format='')


def load_vidar_dat(filename, frame_cnt=None, width=640, height=480, reverse_spike=True):
    '''
    Load a bit-packed Vidar .dat spike file.
    Output: np.ndarray of shape (frame_cnt, height, width), values in {0, 1}, float32.
    '''
    array = np.fromfile(filename, dtype=np.uint8)

    # Each byte packs 8 spikes, so a frame occupies height * width / 8 bytes.
    len_per_frame = height * width // 8
    framecnt = frame_cnt if frame_cnt is not None else len(array) // len_per_frame

    spikes = []
    for i in range(framecnt):
        compr_frame = array[i * len_per_frame: (i + 1) * len_per_frame]
        blist = []
        for b in range(8):
            # Extract bit b from every byte of the frame.
            blist.append(np.right_shift(np.bitwise_and(
                compr_frame, np.left_shift(1, b)), b))

        frame_ = np.stack(blist).transpose()
        frame_ = frame_.reshape((height, width), order='C')
        if reverse_spike:
            frame_ = np.flipud(frame_)
        spikes.append(frame_)

    return np.array(spikes).astype(np.float32)
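
# load_vidar_dat targets 640x480 Vidar .dat streams read straight from disk;
# the inference path in main() instead unpacks its bytes with RawToSpike below.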


def RawToSpike(video_seq, h, w, flipud=True):
    '''
    Unpack a raw bit-packed spike byte stream into a binary spike matrix
    of shape (img_num, h, w), dtype uint8.
    '''
    video_seq = np.array(video_seq).astype(np.uint8)
    img_size = h * w
    img_num = len(video_seq) // (img_size // 8)
    SpikeMatrix = np.zeros([img_num, h, w], np.uint8)

    # Precompute, per pixel, the byte that holds it and the bit mask within that byte.
    pix_id = np.arange(0, h * w)
    pix_id = np.reshape(pix_id, (h, w))
    comparator = np.left_shift(1, np.mod(pix_id, 8))
    byte_id = pix_id // 8

    for img_id in np.arange(img_num):
        id_start = int(img_id) * int(img_size) // 8
        id_end = id_start + int(img_size) // 8
        cur_info = video_seq[id_start:id_end]
        data = cur_info[byte_id]
        result = np.bitwise_and(data, comparator)
        if flipud:
            SpikeMatrix[img_id, :, :] = np.flipud(result == comparator)
        else:
            SpikeMatrix[img_id, :, :] = (result == comparator)

    return SpikeMatrix
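
# Usage sketch, mirroring main() below (346x260 sensor, bit-packed byte stream):
#   with open(spike_path, 'rb') as f:
#       spike_seq = np.frombuffer(f.read(), 'b')
#   spikes = RawToSpike(spike_seq, 260, 346).astype(np.float32)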


def ensure_dir(path):
    os.makedirs(path, exist_ok=True)


def make_colormap(img, color_mapper):
    # Invert the depth so near is bright, normalize to [0, 1], and map to RGBA.
    color_map_inv = np.ones_like(img[0]) * np.amax(img[0]) - img[0]
    color_map_inv = np.nan_to_num(color_map_inv, nan=1)
    color_map_inv = color_map_inv / np.amax(color_map_inv)
    color_map_inv = np.nan_to_num(color_map_inv)
    color_map_inv = color_mapper.to_rgba(color_map_inv)
    # Swap RGB -> BGR so cv2.imwrite stores the channels in the expected order.
    color_map_inv[:, :, 0:3] = color_map_inv[:, :, 0:3][..., ::-1]
    return color_map_inv
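
# make_colormap expects `img` of shape (1, H, W) plus a cm.ScalarMappable;
# main() builds that mapper with a Normalize clipped at the 95th percentile of
# the inverted depth, so a few far outliers do not wash out the colormap.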


def main(config, initial_checkpoint, spike_path):
    use_phased_arch = config['use_phased_arch']

    # Propagate top-level and data-loader settings into the model config.
    config['model']['gpu'] = config['gpu']
    config['model']['every_x_rgb_frame'] = config['data_loader']['train']['every_x_rgb_frame']
    config['model']['baseline'] = config['data_loader']['train']['baseline']
    config['model']['loss_composition'] = config['trainer']['loss_composition']

    # config['arch'] names a model class imported above (e.g. S2DepthTransformerUNetConv).
    model = eval(config['arch'])(config['model'])
    if initial_checkpoint is not None:
        print('Loading initial model weights from: {}'.format(initial_checkpoint))
        checkpoint = torch.load(initial_checkpoint)
        model = torch.nn.DataParallel(model).cuda()
        if use_phased_arch:
            # A dummy forward pass builds the phased architecture's internal
            # state so that its keys exist before the checkpoint is loaded.
            C, (H, W) = config["model"]["num_bins_events"], config["model"]["spatial_resolution"]
            dummy_input = torch.Tensor(1, C, H, W)
            times = torch.Tensor(1)
            _ = model.forward(dummy_input, times=times, prev_states=None)
        model.load_state_dict(checkpoint['state_dict'])

    gpu = torch.device('cuda:' + str(config['gpu']))
    model.to(gpu)
    model.eval()

    data_transform = CenterCrop(224)

    # Decode the raw spike stream (346x260 sensor) into a binary float tensor.
    with open(spike_path, 'rb') as f:
        spike_seq = np.frombuffer(f.read(), 'b')
    spikes = RawToSpike(spike_seq, 260, 346).astype(np.float32)
    data = data_transform(torch.from_numpy(spikes))

    with torch.no_grad():
        print(data.shape)
        dT, dH, dW = data.shape

        # Feed the 128 central frames of the stream to the network.
        model_input = {'image': data[None, dT // 2 - 64: dT // 2 + 64]}
        prev_super_states = {'image': None}
        prev_states_lstm = {}

        new_predicted_targets, _, _ = model(model_input, prev_super_states['image'], prev_states_lstm)

        predict_depth = new_predicted_targets['image']
        print(predict_depth.shape)
        img = predict_depth[0].cpu().numpy()
        cv2.imwrite(f'{os.path.basename(spike_path).split(".")[0]}_image.png', img[0][:, :, None] * 255.0)

        # Save the mean spike rate as a grayscale preview of the input.
        spikes = data.permute(1, 2, 0).cpu().numpy()
        input_spikes = np.mean(spikes, axis=2).astype(np.float32)
        cv2.imwrite(f'{os.path.basename(spike_path).split(".")[0]}_spike.png', input_spikes[:, :, None] * 255.0)

        # Build the color mapper from the prediction itself: invert and
        # normalize the depth, then clip the range at the 95th percentile.
        color_map_inv = np.ones_like(img[0]) * np.amax(img[0]) - img[0]
        color_map_inv = np.nan_to_num(color_map_inv, nan=1)
        color_map_inv = color_map_inv / np.amax(color_map_inv)
        color_map_inv = np.nan_to_num(color_map_inv)
        vmax = np.percentile(color_map_inv, 95)
        normalizer = mpl.colors.Normalize(vmin=color_map_inv.min(), vmax=vmax)
        color_mapper_overall = cm.ScalarMappable(norm=normalizer, cmap='magma')

        color_map = make_colormap(img, color_mapper_overall)
        cv2.imwrite(f'{os.path.basename(spike_path).split(".")[0]}_colormap.png', color_map * 255.0)
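
# main() writes three PNGs named after the input file's stem: the raw depth
# prediction (*_image.png), the mean input spike rate (*_spike.png), and a
# magma-colormapped depth visualization (*_colormap.png).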


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Inference depth map from monocular spike stream')
    parser.add_argument('--path_to_model', type=str,
                        help='path to the model weights',
                        default='SpikeT/s2d_weights/debug_A100_SpikeTransformerUNetConv_LocalGlobal-Swin3D-T/model_best.pth.tar')
    parser.add_argument('--config', type=str,
                        help='path to config. If not specified, config from model folder is taken',
                        default=None)
    parser.add_argument('--data_folder', type=str,
                        help='path to folder of data to be tested',
                        default=None)
    args = parser.parse_args()

    if args.config is None:
        head_tail = os.path.split(args.path_to_model)
        with open(os.path.join(head_tail[0], 'config.json')) as f:
            config = json.load(f)
    else:
        with open(args.config) as f:
            config = json.load(f)

    # Alternative inputs kept for reference; only the last assignment is used.
    # spike_path = 'driving_outdoor0.dat'
    # spike_path = 'dense276.npy'
    spike_path = 'dense082.npy'

    main(config, args.path_to_model, f'asset/{spike_path}')
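
# Example invocation (assuming this script is saved as inference.py):
#   python inference.py --path_to_model SpikeT/s2d_weights/debug_A100_SpikeTransformerUNetConv_LocalGlobal-Swin3D-T/model_best.pth.tar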