File size: 2,080 Bytes
9ce00e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
from utils_spike import STPFilter

def generate_motion_mask(spikes, spike_h, spike_w, device, diff_time=1, **STPargs):
    """
    从脉冲数据中生成运动掩码,使用 STPFilter 类。

    参数:
        spikes (torch.Tensor):脉冲数据,形状为 (time_steps, spike_h, spike_w),值为 0 或 1
        spike_h (int):脉冲数据的高度
        spike_w (int):脉冲数据的宽度
        device (torch.device):计算设备(如 torch.device('cuda') 或 torch.device('cpu'))
        diff_time (int):记录历史动态的窗口长度,默认为 1
        **STPargs:STPFilter 初始化所需的额外参数(如 u0, D, F, f, time_unit 等)

    返回:
        motion_mask (torch.Tensor):运动掩码,形状为 (time_steps, spike_h, spike_w),值为 0 或 1
    """
    # 初始化 STPFilter 实例
    stp_filter = STPFilter(spike_h, spike_w, device, diff_time, **STPargs)
    
    # 获取时间步数
    time_steps = spikes.shape[0]
    
    # 预分配运动掩码张量,提高效率
    motion_mask = torch.zeros(time_steps, spike_h, spike_w, device=device)
    
    # 逐时间步处理脉冲数据
    for t in range(time_steps):
        cur_spikes = spikes[t]  # 当前时间步的脉冲,形状为 (spike_h, spike_w)
        stp_filter.update_dynamics(t, cur_spikes)  # 更新动态并生成 filter_spk
        motion_mask[t] = stp_filter.filter_spk  # 记录当前时间步的运动掩码
    
    return motion_mask




# 示例用法
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
spikes = torch.randint(0, 2, (100, 64, 64), dtype=torch.float, device=device)  # 随机脉冲数据
spike_h, spike_w = 64, 64

# 使用默认参数
motion_mask = generate_motion_mask(spikes, spike_h, spike_w, device)

# 使用自定义参数(参考图片描述)
STPargs = {'u0': 0.1, 'D': 200/2000, 'F': 400/2000, 'f': 0.11, 'time_unit': 2000}
motion_mask = generate_motion_mask(spikes, spike_h, spike_w, device, diff_time=5, **STPargs)

print(motion_mask.shape)  # 输出:torch.Size([100, 64, 64])