"""Reconstruct images from spike-camera streams and estimate depth on them.

The module bundles helpers for reading/writing bit-packed spike ``.dat``
files, two simple spike-to-image reconstructions (TFI / TFP), thin wrappers
around the BSF reconstruction network and Depth Anything V2, and a script
entry point that processes a few hard-coded outdoor sequences.
"""
import argparse
import functools
import json
import logging
import os
from collections import OrderedDict
from os.path import join

import cv2
import gradio as gr
import matplotlib
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

from depth_anything_v2.dpt import DepthAnythingV2
# NOTE: ``load_vidar_dat`` imported here is shadowed by the local definition
# further down; the import is kept so the module's import set is unchanged.
from utils_spike import spikes_to_bsf, load_vidar_dat
from recon.recon_bsf.bsf_utils import InputPadder
from recon.recon_bsf.bsf.bsf import BSF

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Constructor arguments for each Depth Anything V2 encoder variant.
_DEPTH_MODEL_CONFIGS = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]},
}
_DEPTH_ENCODER = 'vits'

# Settings used by the script entry point.
_SPIKE_DIR = 'C:/Users/lze/Desktop/dat/MDE_Dataset/Outdoor-Spike'
_SPIKE_HEIGHT, _SPIKE_WIDTH = 250, 400
_FRAME_START, _FRAME_STOP = 9800, 10200
# Per-sequence spatial crop applied after decoding: (row slice, column slice).
_SEQUENCE_CROPS = {
    "08": (slice(15, -16), slice(89, -92)),
    "26": (slice(18, -18), slice(87, -99)),
    "28": (slice(13, -13), slice(88, -88)),
}


def save_vidar_dat(save_path, SpikeSeq, filpud=True, delete_if_exists=True):
    """Pack a spike sequence into bytes and append it to ``save_path``.

    Every group of 8 consecutive pixels (row-major) is stored as one byte,
    the first pixel of the group in the least-significant bit.

    Args:
        save_path: Destination file; opened in append-binary mode.
        SpikeSeq: Array of shape ``(frames, h, w)``; ``h * w`` must be a
            multiple of 8. Presumably holds 0/1 values -- the packing is only
            meaningful for binary input.
        filpud: When true, each frame is flipped vertically before packing.
            (The misspelled name is kept for backward compatibility.)
        delete_if_exists: When true, an existing file at ``save_path`` is
            removed first so the output contains only this sequence.
    """
    if delete_if_exists and os.path.exists(save_path):
        os.remove(save_path)

    sfn, h, w = SpikeSeq.shape
    assert (h * w) % 8 == 0, "h * w must be a multiple of 8 to pack into bytes"

    bit_weights = 2 ** np.arange(8)  # 1, 2, 4, ..., 128
    with open(save_path, 'ab') as fid:
        for frame in SpikeSeq:
            if filpud:
                frame = np.flipud(frame)
            # One row per output byte, one column per bit.
            bits = frame.reshape(-1, 8)
            packed = (bits * bit_weights).sum(axis=1).astype(np.uint8)
            fid.write(packed.tobytes())


def load_vidar_dat(filename, frame_cnt=None, width=640, height=480, reverse_spike=True):
    """Read a bit-packed spike file into a float array.

    Args:
        filename: Path of the ``.dat`` file.
        frame_cnt: Number of frames to decode; ``None`` decodes every complete
            frame present in the file.
        width: Frame width in pixels.
        height: Frame height in pixels.
        reverse_spike: When true, every frame is flipped vertically.

    Returns:
        ``float32`` array of shape ``(frame_cnt, height, width)`` with values
        in {0, 1}.
    """
    packed = np.fromfile(filename, dtype=np.uint8)
    bytes_per_frame = height * width // 8
    n_frames = frame_cnt if frame_cnt is not None else len(packed) // bytes_per_frame

    # Little-endian bit order: bit 0 of each byte is the first pixel.
    bits = np.unpackbits(packed[:n_frames * bytes_per_frame], bitorder='little')
    spikes = bits.reshape(n_frames, height, width)
    if reverse_spike:
        spikes = spikes[:, ::-1, :]
    return spikes.astype(np.float32)


def RawToSpike(video_seq, h, w, flipud=True):
    """Decode a raw packed byte sequence into a spike matrix.

    Args:
        video_seq: Sequence of bytes; each byte stores 8 consecutive pixels,
            least-significant bit first. Values are cast to ``uint8``.
        h: Frame height in pixels.
        w: Frame width in pixels.
        flipud: When true, every decoded frame is flipped vertically.

    Returns:
        ``uint8`` array of shape ``(n_frames, h, w)`` with values in {0, 1},
        where ``n_frames`` is the number of complete frames in ``video_seq``.
    """
    video_seq = np.asarray(video_seq).astype(np.uint8, copy=False)
    bytes_per_frame = (h * w) // 8
    img_num = len(video_seq) // bytes_per_frame

    SpikeMatrix = np.zeros([img_num, h, w], np.uint8)
    # Decode frame by frame so peak memory stays at one output array.
    for img_id in range(img_num):
        start = img_id * bytes_per_frame
        frame = np.unpackbits(
            video_seq[start:start + bytes_per_frame], bitorder='little'
        ).reshape(h, w)
        SpikeMatrix[img_id] = np.flipud(frame) if flipud else frame
    return SpikeMatrix


def spikes_to_middletfi(spike, middle, window=50):
    """Estimate intensity at frame ``middle`` from the inter-spike interval.

    For every pixel the nearest spike at or before ``middle`` and the nearest
    spike after ``middle`` are searched within ``window`` frames; the result
    is the reciprocal of the distance between the two.

    Args:
        spike: Tensor of shape ``(C, H, W)``; presumably 0/1 valued and on the
            CPU, since the index buffers are allocated with ``torch.zeros``.
        middle: Index of the frame to reconstruct.
        window: Maximum search distance on each side.

    Returns:
        ``float`` tensor of shape ``(1, H, W)``.
    """
    C, H, W = spike.shape
    lindex, rindex = torch.zeros([H, W]), torch.zeros([H, W])
    l = middle + 1
    for r in range(middle + 1, middle + window + 1):
        l = l - 1
        if l >= 0:
            # (1 - sign(index)) masks out pixels whose spike was already found.
            # A zero entry means "not found", so a spike at frame 0 is not
            # distinguishable from a miss.
            newpos = spike[l, :, :] * (1 - torch.sign(lindex))
            lindex += l * newpos
        # NOTE(review): the right-hand scan and the exit test below were
        # garbled in the source ("if r=C: break"); they are reconstructed to
        # mirror the left-hand scan -- confirm against the upstream version.
        if r < C:
            newpos = spike[r, :, :] * (1 - torch.sign(rindex))
            rindex += r * newpos
        if l < 0 and r >= C:
            break

    # Pixels without a spike on one side fall back to the window edge.
    rindex[rindex == 0] = window + middle
    lindex[lindex == 0] = middle - window
    interval = rindex - lindex
    tfi = (1.0 / interval).unsqueeze(0)
    return tfi.float()


def spikes_to_tfp(spike, idx, halfwsize):
    """Reconstruct an image by averaging spikes around frame ``idx``.

    The average is taken over frames ``[idx - halfwsize, idx + halfwsize)``,
    i.e. ``2 * halfwsize`` frames, and then min-max normalised.

    Args:
        spike: Tensor whose first axis is time.
        idx: Centre frame index.
        halfwsize: Half of the averaging window length.

    Returns:
        Tensor with the time axis reduced. If the averaged image is constant
        the normalisation divides by zero.
    """
    window = spike[idx - halfwsize:idx + halfwsize]
    tfp_img = torch.mean(window, axis=0)
    spike_min, spike_max = torch.min(tfp_img), torch.max(tfp_img)
    return (tfp_img - spike_min) / (spike_max - spike_min)


@functools.lru_cache(maxsize=None)
def _load_depth_model(encoder):
    """Build the Depth Anything V2 model for ``encoder`` and load its weights."""
    model = DepthAnythingV2(**_DEPTH_MODEL_CONFIGS[encoder])
    filepath = f"checkpoints/depth_anything_v2_{encoder}.pth"
    state_dict = torch.load(filepath, map_location="cpu")
    model.load_state_dict(state_dict)
    return model.to(DEVICE).eval()


def predict_depth(image):
    """Run Depth Anything V2 (``vits`` encoder) on ``image``.

    The model is built once and reused on later calls instead of reloading
    the checkpoint every time.

    Args:
        image: Input passed unchanged to ``DepthAnythingV2.infer_image``;
            presumably an ``(H, W, 3)`` uint8 image in BGR order -- confirm
            against the model's documentation.

    Returns:
        Whatever ``infer_image`` returns for ``image``.
    """
    return _load_depth_model(_DEPTH_ENCODER).infer_image(image)


@functools.lru_cache(maxsize=None)
def _load_bsf_model():
    """Build the BSF network and load ``checkpoints/bsf.pth`` into it."""
    bsf_model = BSF().to(DEVICE).eval()
    bsf_ckpt = torch.load("checkpoints/bsf.pth", weights_only=True, map_location="cpu")
    # Drop the 'module.' prefix from the checkpoint keys so they match the
    # bare model's parameter names.
    new_bsf_ckpt = OrderedDict(
        (k.replace('module.', ''), v) for k, v in bsf_ckpt.items()
    )
    bsf_model.load_state_dict(new_bsf_ckpt)
    return bsf_model


def predict_recon_bsf(spike):
    """Reconstruct the central frame of ``spike`` with the BSF network.

    The network is built once and reused on later calls.

    Args:
        spike: Spike sequence indexed as ``(frames, height, width)``.

    Returns:
        The result of ``spikes_to_bsf`` for the central frame index.
    """
    bsf_model = _load_bsf_model()
    bsf_padder = InputPadder((1, 1, spike.shape[1], spike.shape[2]), padsize=16)
    central_index = spike.shape[0] // 2
    return spikes_to_bsf(spike, bsf_model, bsf_padder, central_index, DEVICE)


def _load_spike_window(path, start, stop, height, width):
    """Decode frames ``[start, stop)`` of a packed spike file as float32.

    Only the bytes of the requested frames are read, which gives the same
    frames as decoding the whole file and slicing afterwards.
    """
    bytes_per_frame = height * width // 8
    with open(path, 'rb') as f:
        f.seek(start * bytes_per_frame)
        raw = f.read((stop - start) * bytes_per_frame)
    packed = np.frombuffer(raw, dtype=np.uint8)
    return RawToSpike(packed, height, width).astype(np.float32)


def main():
    """Reconstruct and depth-estimate the configured outdoor sequences."""
    cmap = matplotlib.colormaps.get_cmap('plasma')

    for i, (rows, cols) in _SEQUENCE_CROPS.items():
        spike_path = f'{_SPIKE_DIR}/seq_{i}.dat'
        spike = _load_spike_window(
            spike_path, _FRAME_START, _FRAME_STOP, _SPIKE_HEIGHT, _SPIKE_WIDTH
        )
        spike = spike[:, rows, cols]
        print(spike.shape)

        recon_bsf = predict_recon_bsf(spike)
        print(type(recon_bsf), recon_bsf.shape)
        # Replicate the single reconstructed channel to get a 3-channel image.
        recon_bsf = recon_bsf.repeat(3, 1, 1)
        print(type(recon_bsf), recon_bsf.shape)

        # Centre-crop to a square.
        th, tw = recon_bsf.shape[1], recon_bsf.shape[2]
        min_dim = min(th, tw)
        top, left = (th - min_dim) // 2, (tw - min_dim) // 2
        bottom, right = (th + min_dim) // 2, (tw + min_dim) // 2
        center_crop = recon_bsf[:, top:bottom, left:right]

        # The tensor may live on the GPU; bring it to the CPU before .numpy().
        recon_bsf = center_crop.detach().cpu().permute(1, 2, 0).numpy()
        recon_bsf = (recon_bsf * 255.0).astype(np.uint8)
        print(type(recon_bsf), recon_bsf.shape)

        # Reverse the channel axis (RGB -> BGR). The source had '::--1', which
        # is a step of +1 and therefore a no-op; the channels are identical
        # here, so the result is the same either way.
        depth = predict_depth(np.ascontiguousarray(recon_bsf[:, :, ::-1]))
        depth_range = depth.max() - depth.min()
        if depth_range > 0:
            depth = (depth - depth.min()) / depth_range * 255.0
        else:
            # Constant depth map: avoid dividing by zero.
            depth = np.zeros_like(depth)
        depth = depth.astype(np.uint8)
        colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)

        recon_bsf_image_path = f'recon_bsf_{i}.png'
        colored_depth_image_path = f'colored_depth_{i}.png'
        # Save the reconstructed image.
        plt.imsave(recon_bsf_image_path, recon_bsf)
        # Save the colour-mapped depth image.
        plt.imsave(colored_depth_image_path, colored_depth)
        print(f'保存图像: {recon_bsf_image_path} 和 {colored_depth_image_path}')


if __name__ == "__main__":
    main()