import torch from utils_spike import STPFilter def generate_motion_mask(spikes, spike_h, spike_w, device, diff_time=1, **STPargs): """ 从脉冲数据中生成运动掩码,使用 STPFilter 类。 参数: spikes (torch.Tensor):脉冲数据,形状为 (time_steps, spike_h, spike_w),值为 0 或 1 spike_h (int):脉冲数据的高度 spike_w (int):脉冲数据的宽度 device (torch.device):计算设备(如 torch.device('cuda') 或 torch.device('cpu')) diff_time (int):记录历史动态的窗口长度,默认为 1 **STPargs:STPFilter 初始化所需的额外参数(如 u0, D, F, f, time_unit 等) 返回: motion_mask (torch.Tensor):运动掩码,形状为 (time_steps, spike_h, spike_w),值为 0 或 1 """ # 初始化 STPFilter 实例 stp_filter = STPFilter(spike_h, spike_w, device, diff_time, **STPargs) # 获取时间步数 time_steps = spikes.shape[0] # 预分配运动掩码张量,提高效率 motion_mask = torch.zeros(time_steps, spike_h, spike_w, device=device) # 逐时间步处理脉冲数据 for t in range(time_steps): cur_spikes = spikes[t] # 当前时间步的脉冲,形状为 (spike_h, spike_w) stp_filter.update_dynamics(t, cur_spikes) # 更新动态并生成 filter_spk motion_mask[t] = stp_filter.filter_spk # 记录当前时间步的运动掩码 return motion_mask # 示例用法 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') spikes = torch.randint(0, 2, (100, 64, 64), dtype=torch.float, device=device) # 随机脉冲数据 spike_h, spike_w = 64, 64 # 使用默认参数 motion_mask = generate_motion_mask(spikes, spike_h, spike_w, device) # 使用自定义参数(参考图片描述) STPargs = {'u0': 0.1, 'D': 200/2000, 'F': 400/2000, 'f': 0.11, 'time_unit': 2000} motion_mask = generate_motion_mask(spikes, spike_h, spike_w, device, diff_time=5, **STPargs) print(motion_mask.shape) # 输出:torch.Size([100, 64, 64])