Spaces:
Runtime error
Runtime error
import torch | |
class STP: | |
def __init__(self, tau_d=7e-3, tau_f=10e-3, U=0.15, C=1.0, device='cpu'): | |
self.tau_d = tau_d # 突触衰减时间常数 | |
self.tau_f = tau_f # 突触恢复时间常数 | |
self.U = U # 使用率 | |
self.C = C # 输入脉冲强度 | |
self.device = device | |
def update(self, R_prev, u_prev): | |
""" 更新 STP 变量 """ | |
u_new = u_prev + self.U * (1 - u_prev) # 更新 u | |
R_new = R_prev * (1 - u_prev) + self.C # 更新 R | |
return R_new, u_new | |
def detect_motion(R_n, R_prev, theta=0.1): | |
""" 计算运动掩码 M """ | |
motion_mask = (torch.abs(R_n - R_prev) >= theta).float() | |
return motion_mask | |
class LIFNeuron: | |
def __init__(self, tau_m=25e-6, v_rest=0.0, v_th=1.0, device='cpu'): | |
self.tau_m = tau_m # 膜时间常数 | |
self.v_rest = v_rest # 静息电位 | |
self.v_th = v_th # 阈值 | |
self.device = device | |
def update(self, v_prev, input_current): | |
""" 更新 LIF 神经元膜电位 """ | |
dv = (- (v_prev - self.v_rest) + input_current) / self.tau_m | |
v_new = v_prev + dv | |
spike = (v_new >= self.v_th).float() # 触发 spike | |
v_new = torch.where(spike > 0, self.v_rest, v_new) # 超过阈值的重置 | |
return v_new, spike | |
# 初始化参数 | |
device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
stp = STP(device=device) | |
lif = LIFNeuron(device=device) | |
# 假设有 t 帧的 spike 数据 | |
num_frames = 100 | |
height, width = 128, 128 # 图像尺寸 | |
R = torch.zeros((height, width), device=device) | |
u = torch.zeros((height, width), device=device) | |
v = torch.zeros((height, width), device=device) | |
# 存储结果 | |
motion_results = [] | |
for t in range(num_frames): | |
R_prev = R.clone() # 记录上一个时间步的 R | |
u_prev = u.clone() | |
# STP 更新 | |
R, u = stp.update(R_prev, u_prev) | |
# 运动检测 | |
motion_mask = detect_motion(R, R_prev, theta=0.1) | |
# LIF 过滤 | |
v, spike_output = lif.update(v, motion_mask) | |
# 保存结果 | |
motion_results.append(spike_output.cpu().numpy()) |