motionmask / test_stpfilter.py
zzzzzeee's picture
Create test_stpfilter.py
9ce00e9 verified
import torch
from utils_spike import STPFilter
def generate_motion_mask(spikes, spike_h, spike_w, device, diff_time=1, **STPargs):
"""
从脉冲数据中生成运动掩码,使用 STPFilter 类。
参数:
spikes (torch.Tensor):脉冲数据,形状为 (time_steps, spike_h, spike_w),值为 0 或 1
spike_h (int):脉冲数据的高度
spike_w (int):脉冲数据的宽度
device (torch.device):计算设备(如 torch.device('cuda') 或 torch.device('cpu'))
diff_time (int):记录历史动态的窗口长度,默认为 1
**STPargs:STPFilter 初始化所需的额外参数(如 u0, D, F, f, time_unit 等)
返回:
motion_mask (torch.Tensor):运动掩码,形状为 (time_steps, spike_h, spike_w),值为 0 或 1
"""
# 初始化 STPFilter 实例
stp_filter = STPFilter(spike_h, spike_w, device, diff_time, **STPargs)
# 获取时间步数
time_steps = spikes.shape[0]
# 预分配运动掩码张量,提高效率
motion_mask = torch.zeros(time_steps, spike_h, spike_w, device=device)
# 逐时间步处理脉冲数据
for t in range(time_steps):
cur_spikes = spikes[t] # 当前时间步的脉冲,形状为 (spike_h, spike_w)
stp_filter.update_dynamics(t, cur_spikes) # 更新动态并生成 filter_spk
motion_mask[t] = stp_filter.filter_spk # 记录当前时间步的运动掩码
return motion_mask
# 示例用法
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
spikes = torch.randint(0, 2, (100, 64, 64), dtype=torch.float, device=device) # 随机脉冲数据
spike_h, spike_w = 64, 64
# 使用默认参数
motion_mask = generate_motion_mask(spikes, spike_h, spike_w, device)
# 使用自定义参数(参考图片描述)
STPargs = {'u0': 0.1, 'D': 200/2000, 'F': 400/2000, 'f': 0.11, 'time_unit': 2000}
motion_mask = generate_motion_mask(spikes, spike_h, spike_w, device, diff_time=5, **STPargs)
print(motion_mask.shape) # 输出:torch.Size([100, 64, 64])