import numpy as np
import torch
from spkProc.tracking.snn_tracker import SNNTracker
import matplotlib.pyplot as plt
import cv2


def load_vidar_dat(filename, frame_cnt=None, width=640, height=480, reverse_spike=True):
    """Load a bit-packed spike stream from a ``.dat`` file.

    Every byte of the file holds eight consecutive pixels, least-significant
    bit first; one frame therefore occupies ``height * width // 8`` bytes.

    Args:
        filename: Path of the raw file to read.
        frame_cnt: Number of frames to decode. ``None`` decodes every complete
            frame contained in the file. Negative values decode nothing.
        width: Frame width in pixels.
        height: Frame height in pixels.
        reverse_spike: If true, flip every frame vertically.

    Returns:
        A C-contiguous float32 array of shape ``(frame_cnt, height, width)``
        whose values are 0 or 1.

    Raises:
        ValueError: If the file holds fewer than ``frame_cnt`` frames, or if
            ``height * width`` is not a multiple of 8 (the reshape fails).
    """
    packed = np.fromfile(filename, dtype=np.uint8)
    len_per_frame = height * width // 8
    framecnt = frame_cnt if frame_cnt is not None else len(packed) // len_per_frame
    # A negative count would otherwise be read by reshape() as "infer this axis".
    framecnt = max(framecnt, 0)

    # bitorder='little' emits bit 0 of every byte first, which is the order the
    # pixels were packed in.
    bits = np.unpackbits(packed[:framecnt * len_per_frame], bitorder='little')
    frames = bits.reshape((framecnt, height, width))
    if reverse_spike:
        frames = frames[:, ::-1]
    # The flipped view has negative strides; materialise a contiguous copy so
    # that downstream consumers such as torch.from_numpy accept single frames.
    return np.ascontiguousarray(frames, dtype=np.float32)


def detect_motion(spikes, calibration_frames=200, device=None):
    """Detect moving pixels in one frame with the SNN tracker.

    The tracker is calibrated on the first ``calibration_frames`` frames and
    its motion estimator is then queried on frame ``calibration_frames``.

    Args:
        spikes: Spike data of shape ``[frames, height, width]``.
        calibration_frames: Number of leading frames used for calibration;
            also the index of the frame that is analysed.
        device: Torch device to run on. Defaults to CUDA when available,
            otherwise CPU.

    Returns:
        Boolean numpy array that is true where the ``motion_id`` returned by
        ``local_wta`` is greater than zero.

    Raises:
        IndexError: If ``spikes`` has no frame at index ``calibration_frames``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    spike_h, spike_w = spikes.shape[1:]

    # Fail before the calibration pass rather than after it.
    if calibration_frames >= spikes.shape[0]:
        raise IndexError(
            f"frame index {calibration_frames} is out of range for "
            f"{spikes.shape[0]} frames"
        )

    # Set up the SNN tracker.
    spike_tracker = SNNTracker(spike_h, spike_w, device, attention_size=15)

    # Calibrate on the leading frames.
    calibration_spikes = spikes[:calibration_frames]
    spike_tracker.calibrate_motion(calibration_spikes, calibration_frames)

    # Run motion estimation on the first frame after the calibration window.
    target_frame = torch.from_numpy(spikes[calibration_frames]).to(device)
    motion_id, _, _ = spike_tracker.motion_estimator.local_wta(
        target_frame, calibration_frames
    )

    # Build the motion mask.
    return (motion_id > 0).cpu().numpy()


def spikes_to_tfi(spk_seq):
    """Turn a binary spike sequence into per-frame reciprocal spike intervals.

    For frame ``i`` the interval is the distance between the next spike at
    frame ``i + 1`` or later and the spike preceding it (index 0 is used when
    there is no earlier spike). Pixels with no later spike, and the whole
    final frame, use the sequence length ``n`` as their interval.

    Args:
        spk_seq: Array of shape ``(n, h, w)`` containing 0/1 spikes.

    Returns:
        float64 array of shape ``(n, h, w)`` holding ``1 / interval``.
    """
    n, h, w = spk_seq.shape
    c_frames = np.zeros_like(spk_seq).astype(np.float64)

    # Forward pass: store the interval closed by a spike at frame i + 1, and
    # leave 0 where that frame has no spike.
    cur_index = np.zeros((h, w))
    for i in range(n - 1):
        last_index = cur_index
        nxt = spk_seq[i + 1]
        cur_index = nxt * (i + 1) + (1 - nxt) * last_index
        c_frames[i] = cur_index - last_index

    # The final frame has no successor; its remaining zeros become n.
    final = c_frames[n - 1]
    final[final == 0] = n

    # Backward pass: fill the zeros with the interval of the next later spike.
    last_interval = np.full((h, w), float(n))
    for i in range(n - 2, -1, -1):
        nxt = spk_seq[i + 1]
        frame = c_frames[i]  # view: edits land directly in c_frames
        last_interval = nxt * frame + (1 - nxt) * last_interval
        empty = frame == 0
        frame[empty] = last_interval[empty]

    return 1.0 / c_frames


def detect_object(spikes, calibration_frames=200, device=None):
    """Run the SNN tracker on 200 frames and save its output video.

    The tracker is calibrated on the first ``calibration_frames`` frames, then
    ``get_results`` is run on the following (up to) 200 frames with video
    saving enabled. The video goes to ``testtest.avi`` in the working
    directory.

    Args:
        spikes: Spike data of shape ``[frames, height, width]``.
        calibration_frames: Number of leading frames used for calibration.
        device: Torch device to run on. Defaults to CUDA when available,
            otherwise CPU.

    Returns:
        Always ``0``; the results are written to disk by the tracker.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    spike_h, spike_w = spikes.shape[1:]

    # Set up the SNN tracker.
    spike_tracker = SNNTracker(spike_h, spike_w, device, attention_size=15)
    spike_tracker.object_cluster.K2 = 4

    # Calibrate on the leading frames.
    calibration_spikes = spikes[:calibration_frames]
    spike_tracker.calibrate_motion(calibration_spikes, calibration_frames)

    # Frames handed to the tracker: the 200 after the calibration window.
    target_frame = spikes[calibration_frames: calibration_frames + 200]
    print(target_frame.shape)

    save_filename = "testtest.avi"
    # Frame size follows the input (previously fixed at 400x250).
    # NOTE(review): assumes get_results writes frames at the spike resolution
    # -- confirm against SNNTracker.
    mov = cv2.VideoWriter(
        save_filename, cv2.VideoWriter_fourcc(*'MJPG'), 30, (spike_w, spike_h)
    )
    try:
        # NOTE(review): the same path is passed as the tracker's result file
        # and used as the video target -- confirm get_results does not also
        # write to it.
        spike_tracker.get_results(target_frame, save_filename, mov, save_video=True)
    finally:
        # Release the writer even when tracking fails.
        mov.release()
        cv2.destroyAllWindows()

    return 0


def main():
    """Load ``0.dat``..``9.dat``, run object detection and save a TFI video."""
    height = 250
    width = 400

    # Load all ten recordings and join them once, instead of re-copying the
    # growing array on every iteration.
    chunks = [
        load_vidar_dat(f"{n}.dat", width=width, height=height) for n in range(10)
    ]
    spikes = np.concatenate(chunks, axis=0)
    print(spikes.shape)

    # Keep every 10th frame.
    spikes = spikes[::10]

    detect_object(spikes, calibration_frames=200)

    tfi = spikes_to_tfi(spikes)

    # Save the reconstructed video.
    save_recon_filename = "tfi.avi"
    recon_mov = cv2.VideoWriter(
        save_recon_filename, cv2.VideoWriter_fourcc(*'MJPG'), 30, (width, height)
    )
    try:
        for frame in tfi:
            frame_norm = (frame * 255).astype(np.uint8)
            frame_rgb = cv2.cvtColor(frame_norm, cv2.COLOR_GRAY2BGR)
            recon_mov.write(frame_rgb)
    finally:
        recon_mov.release()

    # Detect moving objects
    # motion_mask = detect_motion(spikes, calibration_frames=200)
    # print(f"Motion mask shape: {motion_mask.shape}")
    # print(f"Number of motion pixels: {motion_mask.sum()}")

    # Visualise the motion detection result
    # plt.figure(figsize=(10, 5))
    # plt.subplot(1, 2, 1)
    # plt.imshow(spikes[200], cmap='gray')
    # plt.title("Input frame")
    # plt.axis('off')
    # plt.subplot(1, 2, 2)
    # plt.imshow(motion_mask, cmap='gray')
    # plt.title("Motion mask")
    # plt.axis('off')
    # plt.show()

    # Difference between the raw spike frame and the motion mask
    # spike_frame = spikes[200]  # spike frame 200
    # # Difference metrics
    # pixel_diff = np.logical_xor(spike_frame > 0, motion_mask).sum()
    # total_pixels = height * width
    # diff_ratio = pixel_diff / total_pixels
    # print("\nMotion detection analysis:")
    # print(f"Active pixels in the raw spike frame: {(spike_frame > 0).sum()}")
    # print(f"Motion pixels in the motion mask: {motion_mask.sum()}")
    # print(f"Mismatching pixels: {pixel_diff}")
    # print(f"Pixel difference ratio: {diff_ratio:.2%}")
    # # Visualise the difference
    # plt.figure(figsize=(10, 5))
    # plt.subplot(1, 2, 1)
    # plt.imshow(np.logical_xor(spike_frame > 0, motion_mask), cmap='gray')
    # plt.title("Difference map (white indicates inconsistency)")
    # plt.axis('off')
    # plt.subplot(1, 2, 2)
    # plt.imshow(spike_frame > 0, cmap='gray', alpha=0.5)
    # plt.imshow(motion_mask, cmap='Reds', alpha=0.5)
    # plt.title("Overlay (Red: Motion mask, Gray: Original spikes)")
    # plt.axis('off')
    # plt.show()


if __name__ == "__main__":
    main()