File size: 2,125 Bytes
936caec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch

class STP:
    def __init__(self, tau_d=7e-3, tau_f=10e-3, U=0.15, C=1.0, device='cpu'):
        self.tau_d = tau_d  # 突触衰减时间常数
        self.tau_f = tau_f  # 突触恢复时间常数
        self.U = U          # 使用率
        self.C = C          # 输入脉冲强度
        self.device = device

    def update(self, R_prev, u_prev):
        """ 更新 STP 变量 """
        u_new = u_prev + self.U * (1 - u_prev)  # 更新 u
        R_new = R_prev * (1 - u_prev) + self.C  # 更新 R
        return R_new, u_new



def detect_motion(R_n, R_prev, theta=0.1):
    """ 计算运动掩码 M """
    motion_mask = (torch.abs(R_n - R_prev) >= theta).float()
    return motion_mask


class LIFNeuron:
    def __init__(self, tau_m=25e-6, v_rest=0.0, v_th=1.0, device='cpu'):
        self.tau_m = tau_m  # 膜时间常数
        self.v_rest = v_rest  # 静息电位
        self.v_th = v_th  # 阈值
        self.device = device

    def update(self, v_prev, input_current):
        """ 更新 LIF 神经元膜电位 """
        dv = (- (v_prev - self.v_rest) + input_current) / self.tau_m
        v_new = v_prev + dv
        spike = (v_new >= self.v_th).float()  # 触发 spike
        v_new = torch.where(spike > 0, self.v_rest, v_new)  # 超过阈值的重置
        return v_new, spike






# 初始化参数
device = 'cuda' if torch.cuda.is_available() else 'cpu'
stp = STP(device=device)
lif = LIFNeuron(device=device)

# 假设有 t 帧的 spike 数据
num_frames = 100
height, width = 128, 128  # 图像尺寸
R = torch.zeros((height, width), device=device)
u = torch.zeros((height, width), device=device)
v = torch.zeros((height, width), device=device)

# 存储结果
motion_results = []

for t in range(num_frames):
    R_prev = R.clone()  # 记录上一个时间步的 R
    u_prev = u.clone()
    
    # STP 更新
    R, u = stp.update(R_prev, u_prev)
    
    # 运动检测
    motion_mask = detect_motion(R, R_prev, theta=0.1)
    
    # LIF 过滤
    v, spike_output = lif.update(v, motion_mask)
    
    # 保存结果
    motion_results.append(spike_output.cpu().numpy())