|
import argparse
|
|
import json
|
|
import logging
|
|
import os
|
|
from os.path import join
|
|
import cv2
|
|
import matplotlib as mpl
|
|
import matplotlib.cm as cm
|
|
import numpy as np
|
|
import torch
|
|
import matplotlib.pyplot as plt
|
|
import gradio as gr
|
|
import cv2
|
|
import matplotlib
|
|
import numpy as np
|
|
import os
|
|
from PIL import Image
|
|
import torch
|
|
from depth_anything_v2.dpt import DepthAnythingV2
|
|
from utils_spike import spikes_to_bsf, load_vidar_dat
|
|
from recon.recon_bsf.bsf_utils import InputPadder
|
|
from recon.recon_bsf.bsf.bsf import BSF
|
|
from collections import OrderedDict
|
|
|
|
|
|
# Compute device string: 'cuda' when torch reports a CUDA device, otherwise 'cpu'.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
def save_vidar_dat(save_path, SpikeSeq, filpud=True, delete_if_exists=True):
    """Bit-pack a spike sequence and append it to ``save_path``.

    Each frame is flattened row by row; every run of 8 consecutive pixels is
    stored as one byte, with the first pixel of the run in the
    least-significant bit.

    Args:
        save_path: Destination file. It is opened in append mode, so frames
            are added after any existing content unless the file is removed
            first (see ``delete_if_exists``).
        SpikeSeq: Array of shape (frames, height, width). Values are assumed
            to be 0/1 -- they are not validated here; confirm with callers.
        filpud: When true, each frame is flipped vertically before packing.
            (The misspelled name is kept for backward compatibility.)
        delete_if_exists: When true, an existing file at ``save_path`` is
            removed before writing.

    Raises:
        AssertionError: If ``height * width`` is not a multiple of 8.
    """
    if delete_if_exists and os.path.exists(save_path):
        os.remove(save_path)

    sfn, h, w = SpikeSeq.shape
    assert (h * w) % 8 == 0, "height * width must be a multiple of 8"

    # Weight of each pixel inside its byte: 1, 2, 4, ..., 128.
    base = np.power(2, np.linspace(0, 7, 8))

    # The context manager closes the file even if packing a frame raises.
    with open(save_path, 'ab') as fid:
        for img_id in range(sfn):
            frame = SpikeSeq[img_id, :, :]
            if filpud:
                frame = np.flipud(frame)
            # One row per output byte, one column per bit position.
            bits = frame.reshape(-1, 8)
            data = np.sum(bits * base, axis=1).astype(np.uint8)
            fid.write(data.tobytes())
    return
|
|
|
|
def load_vidar_dat(filename, frame_cnt=None, width=640, height=480, reverse_spike=True):
    """Read a bit-packed spike file into a float32 array of 0/1 values.

    Every byte of the file holds 8 consecutive pixels of a frame in row-major
    order, the first pixel in the least-significant bit.

    Args:
        filename: Path of the packed spike file.
        frame_cnt: Number of frames to decode from the start of the file.
            ``None`` decodes every complete frame the file contains.
        width: Frame width in pixels.
        height: Frame height in pixels.
        reverse_spike: When true, each decoded frame is flipped vertically.

    Returns:
        numpy.ndarray of shape (frames, height, width), dtype float32, with
        values in {0, 1}.

    Raises:
        ValueError: If ``frame_cnt`` asks for more frames than the file holds,
            or if ``height * width`` is not a multiple of 8 (the reshape
            cannot succeed).
    """
    packed = np.fromfile(filename, dtype=np.uint8)

    bytes_per_frame = height * width // 8
    available = len(packed) // bytes_per_frame
    n_frames = available if frame_cnt is None else frame_cnt
    if n_frames > available:
        raise ValueError(
            f"requested {n_frames} frames but {filename!r} only holds {available}"
        )
    # A negative request decodes nothing, as an empty range() did before.
    n_frames = max(n_frames, 0)

    # bitorder='little' puts bit 0 of every byte first, which is the order the
    # previous shift-and-mask loop produced.
    bits = np.unpackbits(packed[:n_frames * bytes_per_frame], bitorder='little')
    spikes = bits.reshape((n_frames, height, width))
    if reverse_spike:
        spikes = spikes[:, ::-1, :]

    return spikes.astype(np.float32)
|
|
|
|
|
|
def RawToSpike(video_seq, h, w, flipud=True):
    """Unpack a bit-packed spike byte stream into per-frame binary images.

    Each byte carries 8 consecutive pixels of a frame in row-major order, the
    first pixel in the least-significant bit. Trailing bytes that do not fill
    a whole frame are ignored.

    Args:
        video_seq: 1-D sequence of bytes. It is converted to uint8, so signed
            int8 input (e.g. from ``np.frombuffer(data, 'b')``) is
            reinterpreted modulo 256.
        h: Frame height in pixels.
        w: Frame width in pixels.
        flipud: When true, each frame is flipped vertically.

    Returns:
        numpy.ndarray of shape (frames, h, w), dtype uint8, values 0 or 1.

    Raises:
        ValueError: If ``h * w`` is not a multiple of 8, i.e. a frame does not
            occupy a whole number of bytes.
    """
    frame_pixels = h * w
    if frame_pixels % 8 != 0:
        raise ValueError(
            f"h * w must be a multiple of 8 to be bit-packed, got {h} * {w}"
        )

    # The data is only read, so avoid copying when it is already uint8.
    raw = np.asarray(video_seq).astype(np.uint8, copy=False)
    frame_bytes = frame_pixels // 8
    img_num = len(raw) // frame_bytes

    SpikeMatrix = np.zeros([img_num, h, w], np.uint8)
    # Decoding one frame at a time keeps the temporary unpacked buffer small.
    for img_id in range(img_num):
        packed = raw[img_id * frame_bytes:(img_id + 1) * frame_bytes]
        frame = np.unpackbits(packed, bitorder='little').reshape(h, w)
        SpikeMatrix[img_id] = np.flipud(frame) if flipud else frame

    return SpikeMatrix
|
|
|
|
def spikes_to_middletfi(spike, middle, window=50):
    """Build an image from the spike interval surrounding frame ``middle``.

    For every pixel the nearest spike at or before ``middle`` and the nearest
    spike after ``middle`` are searched within ``window`` frames on each side;
    the output is the reciprocal of the distance between those two frames.

    Args:
        spike: Tensor of shape (C, H, W) indexed by frame along the first
            axis. Values are assumed to be 0/1 -- confirm with callers.
        middle: Frame index the estimate is centred on.
        window: Maximum number of frames searched in each direction.

    Returns:
        float32 tensor of shape (1, H, W).

    Note:
        Frame indices are accumulated into tensors in which 0 means "nothing
        found yet". A spike located at frame 0 therefore cannot be told apart
        from "no spike" and falls back to the default below.
    """
    C, H, W = spike.shape
    # Allocate the accumulators next to the input so that CUDA tensors do not
    # trigger a device-mismatch error (they used to be created on the CPU).
    device = getattr(spike, 'device', None)
    lindex = torch.zeros([H, W], device=device)
    rindex = torch.zeros([H, W], device=device)

    l = middle + 1
    for r in range(middle + 1, middle + window + 1):
        l -= 1
        if l >= 0:
            # (1 - sign(index)) masks out pixels whose index is already set.
            newpos = spike[l, :, :] * (1 - torch.sign(lindex))
            lindex += l * newpos
        if r < C:
            newpos = spike[r, :, :] * (1 - torch.sign(rindex))
            rindex += r * newpos
        if l < 0 and r >= C:
            # Both directions have run off the sequence.
            break

    # Pixels without a spike inside the window take the window edges.
    rindex[rindex == 0] = window + middle
    lindex[lindex == 0] = middle - window

    interval = rindex - lindex
    tfi = (1.0 / interval).unsqueeze(0)
    return tfi.float()
|
|
|
|
|
|
def spikes_to_tfp(spike, idx, halfwsize):
    """Average the spikes around frame ``idx`` and rescale the result to [0, 1].

    Args:
        spike: Floating-point tensor indexed by frame along the first axis.
        idx: Centre frame of the averaging window.
        halfwsize: Half window size; frames ``idx - halfwsize`` (inclusive) to
            ``idx + halfwsize`` (exclusive) are averaged.

    Returns:
        Tensor with the frame axis removed, min-max normalised. An image with
        no contrast (all pixels equal) is returned as zeros.
    """
    # Clamp the start: a negative start would be read as "count from the end"
    # by slicing and select the wrong (often empty) window.
    start = max(idx - halfwsize, 0)
    window = spike[start:idx + halfwsize]

    tfp_img = torch.mean(window, dim=0)
    spike_min, spike_max = torch.min(tfp_img), torch.max(tfp_img)
    span = spike_max - spike_min
    if span == 0:
        # Avoid 0 / 0 = NaN for a constant image.
        return torch.zeros_like(tfp_img)
    return (tfp_img - spike_min) / span
|
|
|
|
|
|
|
|
def predict_depth(image, encoder='vits'):
    """Run a Depth Anything V2 model on ``image`` and return its prediction.

    A fresh model is built and its checkpoint is loaded from
    ``checkpoints/depth_anything_v2_<encoder>.pth`` on every call.

    Args:
        image: Passed unchanged to ``DepthAnythingV2.infer_image``.
            NOTE(review): presumably an H x W x 3 uint8 image -- confirm
            against that method.
        encoder: Which backbone configuration to use: ``'vits'``, ``'vitb'``,
            ``'vitl'`` or ``'vitg'``. Defaults to ``'vits'``, the value that
            used to be hard-coded.

    Returns:
        Whatever ``DepthAnythingV2.infer_image`` returns for ``image``.

    Raises:
        ValueError: If ``encoder`` is not one of the known configurations.
    """
    model_configs = {
        'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
        'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
        'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
        'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]},
    }
    if encoder not in model_configs:
        raise ValueError(
            f"unknown encoder {encoder!r}; expected one of {sorted(model_configs)}"
        )

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = DepthAnythingV2(**model_configs[encoder])
    filepath = f"checkpoints/depth_anything_v2_{encoder}.pth"
    # weights_only=True restricts unpickling to tensors and plain containers,
    # matching how the BSF checkpoint is loaded elsewhere in this file.
    state_dict = torch.load(filepath, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model = model.to(device).eval()

    return model.infer_image(image)
|
|
|
|
|
|
def predict_recon_bsf(spike):
    """Reconstruct an image from ``spike`` with the BSF network.

    The network is created, moved to CUDA when available (CPU otherwise) and
    filled with the weights stored in ``checkpoints/bsf.pth``. Reconstruction
    is delegated to ``spikes_to_bsf`` and targets the central frame
    (``spike.shape[0] // 2``).

    Args:
        spike: Spike sequence whose first three axes are read as
            (frames, height, width).

    Returns:
        The value returned by ``spikes_to_bsf``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = BSF().to(device).eval()

    raw_state = torch.load("checkpoints/bsf.pth", weights_only=True, map_location="cpu")
    # Remove every 'module.' fragment from the parameter names
    # (presumably left by a parallel wrapper at training time -- unverified).
    cleaned_state = OrderedDict(
        (key.replace('module.', ''), value) for key, value in raw_state.items()
    )
    net.load_state_dict(cleaned_state)

    padder = InputPadder((1, 1, spike.shape[1], spike.shape[2]), padsize=16)
    middle_frame = spike.shape[0] // 2
    return spikes_to_bsf(spike, net, padder, middle_frame, device)
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Frame geometry of the recordings and the temporal window that is used.
    frame_h, frame_w = 250, 400
    first_frame, last_frame = 9800, 10200
    bytes_per_frame = frame_h * frame_w // 8

    # Spatial crop applied to each sequence: (row slice, column slice).
    crops = {
        "08": (slice(15, -16), slice(89, -92)),
        "26": (slice(18, -18), slice(87, -99)),
        "28": (slice(13, -13), slice(88, -88)),
    }

    # The colormap does not depend on the sequence, so build it once.
    cmap = matplotlib.colormaps.get_cmap('plasma')

    for seq_id, (rows, cols) in crops.items():
        spike_path = f'C:/Users/lze/Desktop/dat/MDE_Dataset/Outdoor-Spike/seq_{seq_id}.dat'

        # Read only the bytes of the frames that are kept instead of unpacking
        # the whole recording and discarding most of it afterwards.
        with open(spike_path, 'rb') as f:
            f.seek(first_frame * bytes_per_frame)
            spike_seq = f.read((last_frame - first_frame) * bytes_per_frame)
        spike_seq = np.frombuffer(spike_seq, 'b')
        spike = RawToSpike(spike_seq, frame_h, frame_w)
        spike = spike.astype(np.float32)

        spike = spike[:, rows, cols]
        print(spike.shape)

        recon_bsf = predict_recon_bsf(spike)
        print(type(recon_bsf), recon_bsf.shape)
        # Replicate the single channel three times to get an RGB-shaped image.
        recon_bsf = recon_bsf.repeat(3, 1, 1)
        print(type(recon_bsf), recon_bsf.shape)

        # Centre-crop to a square whose side is the shorter image dimension.
        th, tw = recon_bsf.shape[1], recon_bsf.shape[2]
        min_dim = min(th, tw)
        center_crop = recon_bsf[
            :,
            (th - min_dim) // 2:(th + min_dim) // 2,
            (tw - min_dim) // 2:(tw + min_dim) // 2,
        ]
        recon_bsf = (center_crop.permute(1, 2, 0).numpy() * 255.0).astype(np.uint8)
        print(type(recon_bsf), recon_bsf.shape)

        # The slice used to read `::--1`, i.e. step +1, which reversed nothing.
        # Reverse the channel axis as evidently intended; the three channels
        # are identical copies here, so the pixel values are unchanged.
        depth = predict_depth(recon_bsf[:, :, ::-1])

        # Scale the prediction to 0..255; a constant map has no range to scale.
        depth_min = depth.min()
        depth_range = depth.max() - depth_min
        if depth_range > 0:
            depth = (depth - depth_min) / depth_range * 255.0
        else:
            depth = np.zeros_like(depth)
        depth = depth.astype(np.uint8)
        colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)

        recon_bsf_image_path = f'recon_bsf_{seq_id}.png'
        colored_depth_image_path = f'colored_depth_{seq_id}.png'

        plt.imsave(recon_bsf_image_path, recon_bsf)
        plt.imsave(colored_depth_image_path, colored_depth)
        print(f'保存图像: {recon_bsf_image_path} 和 {colored_depth_image_path}')
|
|
|
|
|