import argparse
import json
import logging
import os
from os.path import join

import cv2
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import torch


def save_vidar_dat(save_path, SpikeSeq, flipud=True, delete_if_exists=True):
    '''
    Pack a binary spike cube of shape (frames, h, w) into a Vidar .dat file,
    8 pixels per byte (least-significant bit first).
    '''
    if delete_if_exists:
        if os.path.exists(save_path):
            os.remove(save_path)
    sfn, h, w = SpikeSeq.shape
    assert (h * w) % 8 == 0
    base = np.power(2, np.linspace(0, 7, 8))
    fid = open(save_path, 'ab')
    for img_id in range(sfn):
        if flipud:
            spike = np.flipud(SpikeSeq[img_id, :, :])
        else:
            spike = SpikeSeq[img_id, :, :]
        spike = spike.flatten()
        spike = spike.reshape([int(h * w / 8), 8])
        data = spike * base
        data = np.sum(data, axis=1).astype(np.uint8)
        fid.write(data.tobytes())
    fid.close()
    return


def load_vidar_dat(filename, frame_cnt=None, width=640, height=480, reverse_spike=True):
    '''
    output: (frame_cnt, height, width) {0,1} float32
    '''
    array = np.fromfile(filename, dtype=np.uint8)
    len_per_frame = height * width // 8
    framecnt = frame_cnt if frame_cnt is not None else len(array) // len_per_frame

    spikes = []
    for i in range(framecnt):
        compr_frame = array[i * len_per_frame: (i + 1) * len_per_frame]
        blist = []
        for b in range(8):
            # extract bit b of every byte
            blist.append(np.right_shift(np.bitwise_and(
                compr_frame, np.left_shift(1, b)), b))
        frame_ = np.stack(blist).transpose()
        frame_ = frame_.reshape((height, width), order='C')
        if reverse_spike:
            frame_ = np.flipud(frame_)
        spikes.append(frame_)

    return np.array(spikes).astype(np.float32)


def RawToSpike(video_seq, h, w, flipud=True):
    '''
    Decode a raw byte stream (1 bit per pixel) into a (frames, h, w) uint8 spike matrix.
    '''
    video_seq = np.array(video_seq).astype(np.uint8)
    img_size = h * w
    img_num = len(video_seq) // (img_size // 8)
    SpikeMatrix = np.zeros([img_num, h, w], np.uint8)

    pix_id = np.arange(0, h * w)
    pix_id = np.reshape(pix_id, (h, w))
    comparator = np.left_shift(1, np.mod(pix_id, 8))  # bit mask of each pixel within its byte
    byte_id = pix_id // 8

    for img_id in np.arange(img_num):
        id_start = int(img_id) * int(img_size) // 8
        id_end = int(id_start) + int(img_size) // 8
        cur_info = video_seq[id_start:id_end]
        data = cur_info[byte_id]
        result = np.bitwise_and(data, comparator)
        if flipud:
            SpikeMatrix[img_id, :, :] = np.flipud(result == comparator)
        else:
            SpikeMatrix[img_id, :, :] = (result == comparator)

    return SpikeMatrix


def spikes_to_middletfi(spike, middle, window=50):
    '''
    TFI reconstruction centred at frame `middle`: for every pixel, search up to
    `window` frames backwards and forwards for the nearest spikes and use the
    reciprocal of the inter-spike interval as intensity.
    '''
    C, H, W = spike.shape
    lindex, rindex = torch.zeros([H, W]), torch.zeros([H, W])
    l, r = middle + 1, middle + 1
    for r in range(middle + 1, middle + window + 1):
        l = l - 1
        if l >= 0:
            # record the left spike position only for pixels that have not found one yet
            newpos = spike[l, :, :] * (1 - torch.sign(lindex))
            distance = l * newpos
            lindex += distance
        if r < C:
            # record the right spike position only for pixels that have not found one yet
            newpos = spike[r, :, :] * (1 - torch.sign(rindex))
            distance = r * newpos
            rindex += distance
        if l < 0 and r >= C:
            break
    rindex[rindex == 0] = window + middle
    lindex[lindex == 0] = middle - window
    interval = rindex - lindex
    tfi = 1.0 / interval
    tfi = tfi.unsqueeze(0)
    return tfi.float()


def spikes_to_tfp(spike, idx, halfwsize):
    # TFP reconstruction: average the spikes in a window of 2*halfwsize+1 frames
    # centred at idx, then min-max normalise to [0, 1].
    spike_ = spike[idx - halfwsize: idx + halfwsize + 1]
    tfp_img = torch.mean(spike_, axis=0)
    spike_min, spike_max = torch.min(tfp_img), torch.max(tfp_img)
    tfp_img = (tfp_img - spike_min) / (spike_max - spike_min)
    return tfp_img


if __name__ == "__main__":
    spike_path = 'spike_0000000082.npy'
    spike_path = 'spike_0000000276.npy'
    for i in ["08", "26", "28"]:
        spike_path = f'C:/Users/lze/Desktop/dat/MDE_Dataset/Outdoor-Spike/seq_{i}.dat'
        f = open(spike_path, 'rb')
        spike_seq = f.read()
        spike_seq = np.frombuffer(spike_seq, 'b')
        spikes = RawToSpike(spike_seq, 250, 400)  # Outdoor-Spike frames are 250 x 400
        spikes = spikes.astype(np.float32)
        spikes = torch.from_numpy(spikes)
        f.close()
        # per-sequence crop
        if i == "08":
            spikes = spikes[:, 15:-16, 89:-92]
        elif i == "26":
            spikes = spikes[:, 18:-18, 87:-99]
        elif i == "28":
            spikes = spikes[:, 13:-13, 88:-88]
        print(spikes.shape)
        tfp = spikes_to_tfp(spikes, 10000, 100)
        print(tfp.shape)
        frame_to_plot = tfp.numpy()  # convert the torch tensor to a numpy array
        plt.imshow(frame_to_plot, cmap='gray')  # show with a grayscale colormap
        plt.savefig(f'{i}.png')
        # 8 28 26
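

# ---------------------------------------------------------------------------
# Illustrative sketch (defined but never called): how the helpers above fit
# together, round-tripping a spike cube through save_vidar_dat /
# load_vidar_dat and rendering one TFP frame. The file name 'demo.dat', the
# 500x250x400 random cube and the output name 'demo_tfp.png' are made up for
# this example only.
# ---------------------------------------------------------------------------
def _demo_roundtrip():
    demo = (np.random.rand(500, 250, 400) > 0.9).astype(np.uint8)  # synthetic spikes
    save_vidar_dat('demo.dat', demo)                               # hypothetical output path
    back = load_vidar_dat('demo.dat', width=400, height=250)       # (500, 250, 400) float32
    tfp = spikes_to_tfp(torch.from_numpy(back), idx=250, halfwsize=100)
    plt.imsave('demo_tfp.png', tfp.numpy(), cmap='gray')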