# Page-scrape residue ("Spaces: Running Running"); not part of the program.
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch

from spkProc.tracking.snn_tracker import SNNTracker
def load_vidar_dat(filename, frame_cnt=None, width=640, height=480, reverse_spike=True):
    """Decode a bit-packed spike recording into an array of 0/1 frames.

    Each byte of the file stores eight consecutive pixels, least-significant
    bit first, and pixels fill every frame row by row.

    Args:
        filename: Path of the packed recording.
        frame_cnt: Number of frames to decode from the start of the file.
            ``None`` decodes every complete frame the file contains.
        width: Frame width in pixels.
        height: Frame height in pixels.
        reverse_spike: When true, flip every frame vertically.

    Returns:
        A C-contiguous ``float32`` array of shape ``(frames, height, width)``
        whose values are 0.0 or 1.0. The shape is ``(0, height, width)`` when
        no frame is decoded.

    Raises:
        ValueError: If ``height * width`` is not a positive multiple of 8, or
            if ``frame_cnt`` is negative or exceeds the frames in the file.
    """
    pixels_per_frame = height * width
    if pixels_per_frame <= 0 or pixels_per_frame % 8 != 0:
        raise ValueError(
            f"height * width must be a positive multiple of 8, got {height} * {width}"
        )

    packed = np.fromfile(filename, dtype=np.uint8)
    bytes_per_frame = pixels_per_frame // 8
    available = len(packed) // bytes_per_frame
    n_frames = frame_cnt if frame_cnt is not None else available
    if n_frames < 0 or n_frames > available:
        raise ValueError(
            f"frame_cnt={n_frames} is outside the {available} complete frame(s) in {filename!r}"
        )

    # bitorder="little" emits bit 0 of every byte first, which is the pixel
    # order this format uses; one call decodes all requested frames at once.
    bits = np.unpackbits(packed[:n_frames * bytes_per_frame], bitorder="little")
    frames = bits.reshape(n_frames, height, width)
    if reverse_spike:
        frames = frames[:, ::-1, :]

    # order="C" guarantees positive strides even after the flipped view, so
    # the result can be handed to torch.from_numpy without a further copy.
    return frames.astype(np.float32, order="C")
def detect_motion(spikes, calibration_frames=200, device=None):
    """Compute a motion mask for the first frame after the calibration window.

    An ``SNNTracker`` is calibrated on the leading ``calibration_frames``
    frames, then its motion estimator's ``local_wta`` is evaluated on the
    frame at index ``calibration_frames``.

    Args:
        spikes: NumPy spike array indexed as [frame, row, column].
        calibration_frames: Number of leading frames used for calibration;
            also the index of the frame that is evaluated.
        device: Torch device to run on. ``None`` selects CUDA when it is
            available and the CPU otherwise.

    Returns:
        NumPy array that is true wherever the motion id reported by
        ``local_wta`` is greater than zero.
    """
    run_device = device
    if run_device is None:
        run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    frame_h, frame_w = spikes.shape[1:]
    tracker = SNNTracker(frame_h, frame_w, run_device, attention_size=15)

    # Calibrate the tracker on the leading frames.
    tracker.calibrate_motion(spikes[:calibration_frames], calibration_frames)

    # Evaluate the frame that immediately follows the calibration window.
    probe = torch.from_numpy(spikes[calibration_frames]).to(run_device)
    wta_result = tracker.motion_estimator.local_wta(probe, calibration_frames)

    # local_wta yields three values; only the motion ids are needed here.
    motion_id, _, _ = wta_result
    return (motion_id > 0).cpu().numpy()
def spikes_to_tfi(spk_seq):
    """Convert a spike sequence into per-frame inverse inter-spike intervals.

    Every pixel of every frame is assigned the length of the spike interval
    covering it, and the reciprocal of that length is returned (the function
    name suggests a "texture from interval" reconstruction).

    Spikes are read from frames ``1 .. n-1`` only. Frames before a pixel's
    first spike use that spike's frame index as the interval, and frames
    after the last spike use ``n``.

    Args:
        spk_seq: Array of shape ``(n, h, w)`` holding 0/1 spike values.

    Returns:
        ``float64`` array of shape ``(n, h, w)`` with values in ``(0, 1]``
        for 0/1 input. An empty sequence (``n == 0``) yields an empty array.
    """
    n, h, w = spk_seq.shape
    intervals = np.zeros(spk_seq.shape, dtype=np.float64)
    if n == 0:
        # Nothing to reconstruct; avoids indexing a non-existent last frame.
        return intervals

    # Forward pass: at the frame just before each spike, store the distance
    # to the previous spike (or to frame 0 when there was none).
    prev_spike_idx = np.zeros((h, w))
    for t in range(1, n):
        fired = spk_seq[t]
        spike_idx = fired * t + (1 - fired) * prev_spike_idx
        intervals[t - 1] = spike_idx - prev_spike_idx
        prev_spike_idx = spike_idx

    # The forward pass never writes the last frame, so it takes the
    # fallback interval n everywhere.
    intervals[n - 1] = n

    # Backward pass: spread every recorded interval over the silent frames
    # that precede it; pixels with no later spike keep the fallback n.
    next_interval = np.full((h, w), float(n))
    for t in range(n - 1, 0, -1):
        fired = spk_seq[t]
        current = intervals[t - 1]  # view: edits below land in `intervals`
        next_interval = fired * current + (1 - fired) * next_interval
        silent = current == 0
        current[silent] = next_interval[silent]

    return 1.0 / intervals
def detect_object(spikes, calibration_frames=200, device=None):
    """Run the SNN tracker on the frames after calibration and save a video.

    An ``SNNTracker`` is calibrated on the leading ``calibration_frames``
    frames, then ``get_results`` is run on up to 200 frames that follow,
    with a video writer targeting "testtest.avi".

    Args:
        spikes: Spike array indexed as [frame, row, column].
        calibration_frames: Number of leading frames used for calibration.
        device: Torch device to run on. ``None`` selects CUDA when it is
            available and the CPU otherwise.

    Returns:
        Always 0. No mask is returned; the results are produced as side
        effects of ``SNNTracker.get_results``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    spike_h, spike_w = spikes.shape[1:]

    # Set up the tracker and calibrate it on the leading frames.
    spike_tracker = SNNTracker(spike_h, spike_w, device, attention_size=15)
    spike_tracker.object_cluster.K2 = 4
    spike_tracker.calibrate_motion(spikes[:calibration_frames], calibration_frames)

    # Frames handed to the tracker: the 200 that follow the calibration
    # window (fewer when the recording ends earlier).
    target_frames = spikes[calibration_frames: calibration_frames + 200]
    print(target_frames.shape)

    save_filename = "testtest.avi"
    # The writer takes (width, height); derive it from the input rather than
    # a fixed 400x250 so other resolutions are not written with a
    # mismatched frame size.
    # NOTE(review): assumes get_results writes frames at the spike
    # resolution - confirm against SNNTracker.
    mov = cv2.VideoWriter(
        save_filename, cv2.VideoWriter_fourcc(*'MJPG'), 30, (spike_w, spike_h)
    )
    try:
        # NOTE(review): the same path is passed both as get_results' second
        # argument and as the video writer's target - confirm that the
        # tracker does not also write its own output to that path.
        spike_tracker.get_results(target_frames, save_filename, mov, save_video=True)
    finally:
        # Release the writer even when tracking fails so the file is closed.
        mov.release()
        cv2.destroyAllWindows()
    return 0
if __name__ == "__main__":
    height = 250
    width = 400

    # Read the ten segment files 0.dat .. 9.dat and join them with a single
    # concatenate; growing the array inside the loop would re-copy everything
    # loaded so far on every iteration.
    segments = [
        load_vidar_dat(f"{n}.dat", width=width, height=height)
        for n in range(10)
    ]
    spikes = np.concatenate(segments, axis=0)
    print(spikes.shape)

    # Keep every 10th frame.
    spikes = spikes[::10]

    # detect_object returns only a status code (0); its output is the video
    # it writes, so the return value is not kept.
    detect_object(spikes, calibration_frames=200)

    tfi = spikes_to_tfi(spikes)

    # Save the reconstructed video.
    save_recon_filename = "tfi.avi"
    recon_mov = cv2.VideoWriter(
        save_recon_filename, cv2.VideoWriter_fourcc(*'MJPG'), 30, (width, height)
    )
    try:
        for frame in tfi:
            # Scale the reconstruction to 8-bit grey and expand it to three
            # channels before writing.
            frame_norm = (frame * 255).astype(np.uint8)
            frame_rgb = cv2.cvtColor(frame_norm, cv2.COLOR_GRAY2BGR)
            recon_mov.write(frame_rgb)
    finally:
        # Close the file even if a frame fails to convert or write.
        recon_mov.release()
# Detect moving targets
# motion_mask = detect_motion(spikes, calibration_frames=200) | |
# print(f"Motion mask shape: {motion_mask.shape}") | |
# print(f"Number of motion pixels: {motion_mask.sum()}") | |
# Visualize the motion detection result
# plt.figure(figsize=(10, 5)) | |
# plt.subplot(1, 2, 1) | |
# plt.imshow(spikes[200], cmap='gray') | |
# plt.title("Input frame") | |
# plt.axis('off') | |
# plt.subplot(1, 2, 2) | |
# plt.imshow(motion_mask, cmap='gray') | |
# plt.title("Motion mask") | |
# plt.axis('off') | |
# plt.show() | |
# Compute the difference between the raw spike frame and the motion mask
# spike_frame = spikes[200]  # take spike frame 200
# # Compute the difference metrics
# pixel_diff = np.logical_xor(spike_frame > 0, motion_mask).sum() | |
# total_pixels = height * width | |
# diff_ratio = pixel_diff / total_pixels | |
# print("\n运动检测结果分析:") | |
# print(f"原始脉冲图中的活跃像素数: {(spike_frame > 0).sum()}") | |
# print(f"运动掩码中的运动像素数: {motion_mask.sum()}") | |
# print(f"不一致的像素数: {pixel_diff}") | |
# print(f"像素差异比例: {diff_ratio:.2%}") | |
# # Visualize the difference
# plt.figure(figsize=(10, 5)) | |
# plt.subplot(1, 2, 1) | |
# plt.imshow(np.logical_xor(spike_frame > 0, motion_mask), cmap='gray') | |
# plt.title("Difference map (white indicates inconsistency)") | |
# plt.axis('off') | |
# plt.subplot(1, 2, 2) | |
# plt.imshow(spike_frame > 0, cmap='gray', alpha=0.5) | |
# plt.imshow(motion_mask, cmap='Reds', alpha=0.5) | |
# plt.title("Overlay (Red: Motion mask, Gray: Original spikes)") | |
# plt.axis('off') | |
# plt.show() | |