Spaces:
Running
Running
File size: 7,185 Bytes
9fa5305 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
import numpy as np
import torch
from spkProc.tracking.snn_tracker import SNNTracker
import matplotlib.pyplot as plt
import cv2
def load_vidar_dat(filename, frame_cnt=None, width=640, height=480, reverse_spike=True):
    """Load a bit-packed spike stream (.dat) into a dense spike array.

    Each frame occupies ``height * width // 8`` bytes. Inside a frame, bit
    ``b`` (least-significant bit first) of byte ``k`` is the spike flag of
    pixel ``8 * k + b`` in row-major order.

    Args:
        filename: Path of the ``.dat`` file.
        frame_cnt: Number of frames to decode. ``None`` decodes every complete
            frame in the file; trailing bytes that do not form a whole frame
            are ignored.
        width: Frame width in pixels.
        height: Frame height in pixels.
        reverse_spike: If True, flip every frame vertically (reverse row order).

    Returns:
        C-contiguous ``np.ndarray`` of shape ``(frame_cnt, height, width)``,
        dtype float32, values in {0, 1}.

    Raises:
        ValueError: if ``frame_cnt`` is negative or larger than the number of
            complete frames stored in the file.
    """
    raw = np.fromfile(filename, dtype=np.uint8)
    bytes_per_frame = height * width // 8
    available = len(raw) // bytes_per_frame
    n_frames = available if frame_cnt is None else frame_cnt
    if n_frames < 0 or n_frames > available:
        raise ValueError(
            f"requested {n_frames} frames but {filename!r} holds {available} complete frames"
        )

    packed = raw[: n_frames * bytes_per_frame].reshape(n_frames, bytes_per_frame)
    # bitorder="little" emits bit 0 first for every byte, matching the
    # pixel layout 8*k + b described above.
    bits = np.unpackbits(packed, axis=1, bitorder="little")
    frames = bits.reshape(n_frames, height, width)
    if reverse_spike:
        frames = frames[:, ::-1, :]
    # Copy into a positively-strided array so callers can use torch.from_numpy.
    return np.ascontiguousarray(frames, dtype=np.float32)
def detect_motion(spikes, calibration_frames=200, device=None):
    """Detect moving targets in a spike stream using the SNN tracker.

    The first ``calibration_frames`` frames calibrate the tracker's motion
    model. The frame at index ``calibration_frames`` is then passed to the
    motion estimator's ``local_wta`` step.

    Args:
        spikes: Spike array of shape ``(frames, height, width)``. It must
            contain at least ``calibration_frames + 1`` frames.
        calibration_frames: Number of leading frames used for calibration.
        device: Torch device to run on. ``None`` selects CUDA when available,
            otherwise CPU.

    Returns:
        Boolean ``np.ndarray``, True wherever ``local_wta`` reported a motion
        id greater than 0 for frame ``calibration_frames``.

    Raises:
        ValueError: if ``spikes`` has too few frames for calibration plus one
            target frame.
    """
    n_frames = spikes.shape[0]
    if n_frames <= calibration_frames:
        # Fail before the (expensive) calibration instead of hitting an
        # IndexError on the target frame afterwards.
        raise ValueError(
            f"detect_motion needs at least {calibration_frames + 1} frames "
            f"(calibration + 1 target), got {n_frames}"
        )
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    spike_h, spike_w = spikes.shape[1:]

    spike_tracker = SNNTracker(spike_h, spike_w, device, attention_size=15)
    # Calibrate the tracker's motion model on the leading frames.
    spike_tracker.calibrate_motion(spikes[:calibration_frames], calibration_frames)

    # Evaluate the first frame after the calibration window.
    target_frame = torch.from_numpy(spikes[calibration_frames]).to(device)
    # NOTE(review): the second argument presumably is the frame index/timestamp
    # of target_frame — confirm against SNNTracker.motion_estimator.local_wta.
    motion_id, _motion_vector, _ = spike_tracker.motion_estimator.local_wta(
        target_frame, calibration_frames
    )

    # Pixels with a positive motion id are reported as moving.
    return (motion_id > 0).cpu().numpy()
def spikes_to_tfi(spk_seq):
    """Reconstruct texture-from-interval (TFI) frames from a spike sequence.

    For each pixel, output frame ``t`` (``t < n - 1``) equals ``1 / (j - p)``.
    Here ``j`` is the first spike index ``>= t + 1`` and ``p`` is the spike
    index preceding ``j``, or 0 if there is none; a spike at index 0 is never
    counted. If no spike occurs at or after ``t + 1``, the interval is ``n``.
    The last frame always uses interval ``n``.

    Args:
        spk_seq: Array of shape ``(n, h, w)``. Values are assumed to be
            binary {0, 1}, as produced by ``load_vidar_dat``.

    Returns:
        float64 array of shape ``(n, h, w)`` holding ``1 / interval``.

    Raises:
        IndexError: if ``spk_seq`` has zero frames.
    """
    n, h, w = spk_seq.shape
    intervals = np.zeros_like(spk_seq).astype(np.float64)

    # Forward pass: wherever a spike fires at frame t, store the distance to
    # the previous firing at index t - 1; non-firing entries stay 0.
    prev_fire = np.zeros((1, h, w))
    for t in range(1, n):
        fired = spk_seq[t, :, :]
        fire_idx = fired * t + (1 - fired) * prev_fire
        intervals[t - 1, :, :] = fire_idx - prev_fire
        prev_fire = fire_idx

    # Nothing fires after the final frame, so it gets the maximal interval.
    intervals[n - 1, :, :] = n

    # Backward pass: fill each 0 with the interval ending at the next spike
    # (carried from later frames), defaulting to n when none follows.
    carry = n * np.ones((1, h, w))
    for t in range(n - 1, 0, -1):
        fired = spk_seq[t, :, :]
        # Update the carry from the still-unfilled entry before filling it.
        carry = fired * intervals[t - 1, :, :] + (1 - fired) * carry
        gap = intervals[t - 1, :, :]
        intervals[t - 1, :, :] = np.where(gap == 0, carry[0], gap)

    return 1.0 / intervals
def detect_object(spikes, calibration_frames=200, device=None,
                  save_filename="testtest.avi", num_frames=200, fps=30,
                  frame_size=None):
    """Track objects in a spike stream with the SNN tracker and save a video.

    The first ``calibration_frames`` frames calibrate the tracker's motion
    model. The following ``num_frames`` frames are passed to
    ``SNNTracker.get_results`` together with an open ``cv2.VideoWriter``,
    into which the tracker writes its output video.

    Args:
        spikes: Spike array of shape ``(frames, height, width)``. It must
            contain more than ``calibration_frames`` frames.
        calibration_frames: Number of leading frames used for calibration.
        device: Torch device. ``None`` selects CUDA when available, else CPU.
        save_filename: Output video path (MJPG codec).
        num_frames: Maximum number of post-calibration frames to process.
        fps: Frame rate of the output video.
        frame_size: ``(width, height)`` of the output video. ``None`` uses the
            spike resolution of ``spikes``.

    Returns:
        0. The tracking result is written to ``save_filename``; no mask is
        returned.

    Raises:
        ValueError: if ``spikes`` has no frames after the calibration window.
    """
    n_frames = spikes.shape[0]
    if n_frames <= calibration_frames:
        raise ValueError(
            f"detect_object needs more than {calibration_frames} frames, got {n_frames}"
        )
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    spike_h, spike_w = spikes.shape[1:]

    spike_tracker = SNNTracker(spike_h, spike_w, device, attention_size=15)
    # Override the object-cluster parameter K2; its meaning is defined by
    # SNNTracker's object_cluster (not visible here).
    spike_tracker.object_cluster.K2 = 4
    # Calibrate the tracker's motion model on the leading frames.
    spike_tracker.calibrate_motion(spikes[:calibration_frames], calibration_frames)

    target_frames = spikes[calibration_frames: calibration_frames + num_frames]

    # NOTE(review): assumes get_results writes frames at the spike resolution;
    # OpenCV silently drops frames whose size differs from frameSize.
    if frame_size is None:
        frame_size = (spike_w, spike_h)
    mov = cv2.VideoWriter(save_filename, cv2.VideoWriter_fourcc(*'MJPG'), fps, frame_size)
    try:
        spike_tracker.get_results(target_frames, save_filename, mov, save_video=True)
    finally:
        # Always finalize the container, even if tracking fails midway.
        mov.release()
    cv2.destroyAllWindows()
    return 0
if __name__ == "__main__":
    height = 250
    width = 400

    # The recording is split across 0.dat ... 9.dat. Decode every chunk and
    # join them once along the time axis; concatenating inside the loop
    # would re-copy the growing array on every iteration.
    spikes = np.concatenate(
        [load_vidar_dat(f"{idx}.dat", width=width, height=height) for idx in range(10)],
        axis=0,
    )
    print(spikes.shape)

    # Temporal subsampling: keep every 10th frame.
    spikes = spikes[::10]

    # detect_object writes its tracking video to disk and returns a status code.
    detect_object(spikes, calibration_frames=200)

    # Save the TFI reconstruction as a video.
    tfi = spikes_to_tfi(spikes)
    save_recon_filename = "tfi.avi"
    recon_mov = cv2.VideoWriter(
        save_recon_filename, cv2.VideoWriter_fourcc(*'MJPG'), 30, (width, height)
    )
    try:
        for frame in tfi:
            # Clip before the uint8 cast so out-of-range values cannot wrap around.
            frame_u8 = np.clip(frame * 255, 0, 255).astype(np.uint8)
            recon_mov.write(cv2.cvtColor(frame_u8, cv2.COLOR_GRAY2BGR))
    finally:
        recon_mov.release()

    # For a single-frame motion mask instead of a tracking video, see
    # detect_motion(spikes, calibration_frames=200).
|