# Source: test_embedding_shape / snnTracker / test_motion_detection.py
# Uploaded by zzzzzeee ("Upload 28 files", commit 9fa5305, verified).
import numpy as np
import torch
from spkProc.tracking.snn_tracker import SNNTracker
import matplotlib.pyplot as plt
import cv2
def load_vidar_dat(filename, frame_cnt=None, width=640, height=480, reverse_spike=True):
    """Load a bit-packed spike recording into a dense array.

    Every frame occupies ``height * width // 8`` bytes in the file. Each byte
    carries eight consecutive pixels, least-significant bit first, and the
    pixels of a frame are laid out in row-major (C) order.

    Args:
        filename: Path of the recording (anything ``np.fromfile`` accepts).
        frame_cnt: Number of frames to decode from the start of the file.
            ``None`` decodes every complete frame; trailing bytes that do not
            fill a whole frame are ignored.
        width: Frame width in pixels.
        height: Frame height in pixels.
        reverse_spike: When true, each frame is flipped upside down.

    Returns:
        ``numpy.ndarray`` of shape ``(frame_cnt, height, width)`` and dtype
        float32 whose values are 0 or 1.

    Raises:
        ValueError: If ``height * width`` is not a positive multiple of 8,
            or if ``frame_cnt`` is negative or larger than the number of
            complete frames stored in the file.
    """
    pixels_per_frame = height * width
    if pixels_per_frame <= 0 or pixels_per_frame % 8:
        raise ValueError(
            f"height * width must be a positive multiple of 8, got {height} * {width}"
        )
    len_per_frame = pixels_per_frame // 8

    array = np.fromfile(filename, dtype=np.uint8)
    available = len(array) // len_per_frame
    if frame_cnt is None:
        frame_cnt = available
    elif frame_cnt < 0 or frame_cnt > available:
        raise ValueError(
            f"frame_cnt={frame_cnt} is outside the {available} complete "
            f"frame(s) stored in {filename!r}"
        )

    # Decode all requested frames at once; bitorder='little' yields bit 0 of
    # each byte first, matching the on-disk pixel order.
    bits = np.unpackbits(array[:frame_cnt * len_per_frame], bitorder='little')
    spikes = bits.reshape((frame_cnt, height, width))
    if reverse_spike:
        spikes = spikes[:, ::-1, :]  # vertical flip of every frame
    # Produce a fresh C-contiguous float32 array (the flipped view has
    # negative strides).
    return np.ascontiguousarray(spikes, dtype=np.float32)
def detect_motion(spikes, calibration_frames=200, device=None):
    """Detect moving pixels in one spike frame with the SNN tracker.

    The first ``calibration_frames`` frames calibrate the tracker's motion
    estimator; the frame at index ``calibration_frames`` is then passed to
    ``motion_estimator.local_wta`` and thresholded into a mask.

    Args:
        spikes: Spike data indexed as [frames, height, width]. The target
            frame is handed to ``torch.from_numpy``, so this must be a NumPy
            array, and it must hold more than ``calibration_frames`` frames.
        calibration_frames: Number of leading frames used for calibration;
            also the index of the frame that is evaluated.
        device: Device to run on. Defaults to CUDA when available, else CPU.

    Returns:
        NumPy array that is true where the ``motion_id`` returned by
        ``local_wta`` is greater than 0 (its shape is whatever ``local_wta``
        produces -- presumably (height, width); confirm against SNNTracker).

    Raises:
        IndexError: If ``spikes`` has no frame at index ``calibration_frames``.
    """
    # Fail fast: without this check the missing target frame is only
    # noticed after the tracker has been built and calibrated.
    if calibration_frames >= spikes.shape[0]:
        raise IndexError(
            f"need a frame at index {calibration_frames} but spikes only "
            f"has {spikes.shape[0]} frame(s)"
        )

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    spike_h, spike_w = spikes.shape[1:]

    # Initialise the SNN tracker.
    spike_tracker = SNNTracker(spike_h, spike_w, device, attention_size=15)

    # Calibrate on the first calibration_frames frames.
    calibration_spikes = spikes[:calibration_frames]
    spike_tracker.calibrate_motion(calibration_spikes, calibration_frames)

    # Run motion estimation on the frame right after the calibration window.
    target_frame = torch.from_numpy(spikes[calibration_frames]).to(device)
    motion_id, _, _ = spike_tracker.motion_estimator.local_wta(
        target_frame, calibration_frames)

    # Build the motion mask from the ids.
    motion_mask = (motion_id > 0).cpu().numpy()
    return motion_mask
def spikes_to_tfi(spk_seq):
    """Convert a spike sequence into per-pixel reciprocal inter-spike intervals.

    Args:
        spk_seq: Array indexed as (frames, height, width). The arithmetic
            treats 1 as "spike" and 0 as "no spike".

    Returns:
        float64 array with the shape of ``spk_seq``. Entry ``t`` holds
        ``1 / interval``, where ``interval`` is the spacing of the spike pair
        that ends at the next spike after frame ``t`` (the first spike is
        measured from frame index 0). Frames after a pixel's last spike, and
        the final frame, use the sequence length as the interval.
    """
    num_frames, rows, cols = spk_seq.shape
    intervals = np.zeros_like(spk_seq).astype(np.float64)

    # Forward sweep: track the index of the most recent spike per pixel and
    # store, at slot t - 1, the gap closed by a spike at frame t (0 if none).
    prev_spike_idx = np.zeros((1, rows, cols))
    for t in range(1, num_frames):
        fired = spk_seq[t, :, :]
        spike_idx = fired * t + (1 - fired) * prev_spike_idx
        intervals[t - 1, :, :] = spike_idx - prev_spike_idx
        prev_spike_idx = spike_idx

    # The final slot is never written above; it gets the sequence length.
    final = intervals[num_frames - 1]
    final[final == 0] = num_frames

    # Backward sweep: fill every empty slot with the interval of the next
    # spike that follows it (or the sequence length if none follows).
    pending = num_frames * np.ones((1, rows, cols))
    for t in reversed(range(num_frames - 1)):
        fired = spk_seq[t + 1, :, :]
        pending = fired * intervals[t, :, :] + (1 - fired) * pending
        current = intervals[t]
        intervals[t] = np.where(current == 0, pending[0], current)

    return 1.0 / intervals
def detect_object(spikes, calibration_frames=200, device=None, *,
                  save_filename="testtest.avi", video_size=(400, 250),
                  num_frames=200):
    """Run the SNN tracker on a spike stream and save its result video.

    The first ``calibration_frames`` frames calibrate the tracker; the next
    ``num_frames`` frames are passed to ``SNNTracker.get_results`` together
    with an MJPG, 30 fps ``cv2.VideoWriter``.

    Args:
        spikes: Spike data indexed as [frames, height, width].
        calibration_frames: Number of leading frames used for calibration.
        device: Device to run on. Defaults to CUDA when available, else CPU.
        save_filename: Output path given to both the video writer and
            ``get_results``.
        video_size: ``(width, height)`` of the video writer.
            NOTE(review): the default is a fixed 400x250 rather than being
            derived from ``spikes``; confirm what frame size ``get_results``
            actually writes before changing it.
        num_frames: Number of frames after the calibration window to process.

    Returns:
        Always 0. The results are produced as side effects of
        ``get_results`` (no mask is returned).
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    spike_h, spike_w = spikes.shape[1:]

    # Initialise the SNN tracker.
    spike_tracker = SNNTracker(spike_h, spike_w, device, attention_size=15)
    # Override a clustering setting; meaning of K2 is defined by SNNTracker
    # (TODO confirm why 4 is used here).
    spike_tracker.object_cluster.K2 = 4

    # Calibrate on the first calibration_frames frames.
    calibration_spikes = spikes[:calibration_frames]
    spike_tracker.calibrate_motion(calibration_spikes, calibration_frames)

    # Frames that follow the calibration window are the ones being tracked.
    target_frame = spikes[calibration_frames: calibration_frames + num_frames]
    print(target_frame.shape)

    mov = cv2.VideoWriter(save_filename, cv2.VideoWriter_fourcc(*'MJPG'), 30, video_size)
    try:
        spike_tracker.get_results(target_frame, save_filename, mov, save_video=True)
    finally:
        # Always finalise the video file, even if tracking raises.
        mov.release()
        cv2.destroyAllWindows()
    return 0
if __name__ == "__main__":
    height = 250
    width = 400

    # Load recordings 0.dat .. 9.dat and join them along the frame axis with
    # a single concatenate, so the growing array is not re-copied per file.
    spikes = np.concatenate(
        [load_vidar_dat(f"{n}.dat", width=width, height=height) for n in range(10)],
        axis=0,
    )
    print(spikes.shape)

    # Keep every 10th frame.
    spikes = spikes[::10]

    # detect_object writes its output as a side effect and returns 0, so the
    # return value is not kept.
    detect_object(spikes, calibration_frames=200)

    tfi = spikes_to_tfi(spikes)

    # Save the reconstructed video.
    save_recon_filename = "tfi.avi"
    recon_mov = cv2.VideoWriter(
        save_recon_filename, cv2.VideoWriter_fourcc(*'MJPG'), 30, (width, height))
    try:
        for frame in tfi:
            # spikes_to_tfi returns reciprocal intervals; scale them to 8-bit grey.
            frame_norm = (frame * 255).astype(np.uint8)
            frame_rgb = cv2.cvtColor(frame_norm, cv2.COLOR_GRAY2BGR)
            recon_mov.write(frame_rgb)
    finally:
        # Finalise the file even if a frame fails to convert or write.
        recon_mov.release()
# 检测运动目标
# motion_mask = detect_motion(spikes, calibration_frames=200)
# print(f"Motion mask shape: {motion_mask.shape}")
# print(f"Number of motion pixels: {motion_mask.sum()}")
# 可视化运动目标检测结果
# plt.figure(figsize=(10, 5))
# plt.subplot(1, 2, 1)
# plt.imshow(spikes[200], cmap='gray')
# plt.title("Input frame")
# plt.axis('off')
# plt.subplot(1, 2, 2)
# plt.imshow(motion_mask, cmap='gray')
# plt.title("Motion mask")
# plt.axis('off')
# plt.show()
# 计算原始脉冲图和运动掩码之间的差异
# spike_frame = spikes[200] # 获取第200帧脉冲图
# # 计算差异指标
# pixel_diff = np.logical_xor(spike_frame > 0, motion_mask).sum()
# total_pixels = height * width
# diff_ratio = pixel_diff / total_pixels
# print("\n运动检测结果分析:")
# print(f"原始脉冲图中的活跃像素数: {(spike_frame > 0).sum()}")
# print(f"运动掩码中的运动像素数: {motion_mask.sum()}")
# print(f"不一致的像素数: {pixel_diff}")
# print(f"像素差异比例: {diff_ratio:.2%}")
# # 可视化差异
# plt.figure(figsize=(10, 5))
# plt.subplot(1, 2, 1)
# plt.imshow(np.logical_xor(spike_frame > 0, motion_mask), cmap='gray')
# plt.title("Difference map (white indicates inconsistency)")
# plt.axis('off')
# plt.subplot(1, 2, 2)
# plt.imshow(spike_frame > 0, cmap='gray', alpha=0.5)
# plt.imshow(motion_mask, cmap='Reds', alpha=0.5)
# plt.title("Overlay (Red: Motion mask, Gray: Original spikes)")
# plt.axis('off')
# plt.show()