Spaces:
Runtime error
Runtime error
import torch | |
from utils_spike import STPFilter | |
def generate_motion_mask(spikes, spike_h, spike_w, device, diff_time=1, **STPargs): | |
""" | |
从脉冲数据中生成运动掩码,使用 STPFilter 类。 | |
参数: | |
spikes (torch.Tensor):脉冲数据,形状为 (time_steps, spike_h, spike_w),值为 0 或 1 | |
spike_h (int):脉冲数据的高度 | |
spike_w (int):脉冲数据的宽度 | |
device (torch.device):计算设备(如 torch.device('cuda') 或 torch.device('cpu')) | |
diff_time (int):记录历史动态的窗口长度,默认为 1 | |
**STPargs:STPFilter 初始化所需的额外参数(如 u0, D, F, f, time_unit 等) | |
返回: | |
motion_mask (torch.Tensor):运动掩码,形状为 (time_steps, spike_h, spike_w),值为 0 或 1 | |
""" | |
# 初始化 STPFilter 实例 | |
stp_filter = STPFilter(spike_h, spike_w, device, diff_time, **STPargs) | |
# 获取时间步数 | |
time_steps = spikes.shape[0] | |
# 预分配运动掩码张量,提高效率 | |
motion_mask = torch.zeros(time_steps, spike_h, spike_w, device=device) | |
# 逐时间步处理脉冲数据 | |
for t in range(time_steps): | |
cur_spikes = spikes[t] # 当前时间步的脉冲,形状为 (spike_h, spike_w) | |
stp_filter.update_dynamics(t, cur_spikes) # 更新动态并生成 filter_spk | |
motion_mask[t] = stp_filter.filter_spk # 记录当前时间步的运动掩码 | |
return motion_mask | |
# 示例用法 | |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
spikes = torch.randint(0, 2, (100, 64, 64), dtype=torch.float, device=device) # 随机脉冲数据 | |
spike_h, spike_w = 64, 64 | |
# 使用默认参数 | |
motion_mask = generate_motion_mask(spikes, spike_h, spike_w, device) | |
# 使用自定义参数(参考图片描述) | |
STPargs = {'u0': 0.1, 'D': 200/2000, 'F': 400/2000, 'f': 0.11, 'time_unit': 2000} | |
motion_mask = generate_motion_mask(spikes, spike_h, spike_w, device, diff_time=5, **STPargs) | |
print(motion_mask.shape) # 输出:torch.Size([100, 64, 64]) |