motionmask / chatgpt.py
zzzzzeee's picture
Create chatgpt.py
936caec verified
import torch
class STP:
def __init__(self, tau_d=7e-3, tau_f=10e-3, U=0.15, C=1.0, device='cpu'):
self.tau_d = tau_d # 突触衰减时间常数
self.tau_f = tau_f # 突触恢复时间常数
self.U = U # 使用率
self.C = C # 输入脉冲强度
self.device = device
def update(self, R_prev, u_prev):
""" 更新 STP 变量 """
u_new = u_prev + self.U * (1 - u_prev) # 更新 u
R_new = R_prev * (1 - u_prev) + self.C # 更新 R
return R_new, u_new
def detect_motion(R_n, R_prev, theta=0.1):
""" 计算运动掩码 M """
motion_mask = (torch.abs(R_n - R_prev) >= theta).float()
return motion_mask
class LIFNeuron:
def __init__(self, tau_m=25e-6, v_rest=0.0, v_th=1.0, device='cpu'):
self.tau_m = tau_m # 膜时间常数
self.v_rest = v_rest # 静息电位
self.v_th = v_th # 阈值
self.device = device
def update(self, v_prev, input_current):
""" 更新 LIF 神经元膜电位 """
dv = (- (v_prev - self.v_rest) + input_current) / self.tau_m
v_new = v_prev + dv
spike = (v_new >= self.v_th).float() # 触发 spike
v_new = torch.where(spike > 0, self.v_rest, v_new) # 超过阈值的重置
return v_new, spike
# 初始化参数
device = 'cuda' if torch.cuda.is_available() else 'cpu'
stp = STP(device=device)
lif = LIFNeuron(device=device)
# 假设有 t 帧的 spike 数据
num_frames = 100
height, width = 128, 128 # 图像尺寸
R = torch.zeros((height, width), device=device)
u = torch.zeros((height, width), device=device)
v = torch.zeros((height, width), device=device)
# 存储结果
motion_results = []
for t in range(num_frames):
R_prev = R.clone() # 记录上一个时间步的 R
u_prev = u.clone()
# STP 更新
R, u = stp.update(R_prev, u_prev)
# 运动检测
motion_mask = detect_motion(R, R_prev, theta=0.1)
# LIF 过滤
v, spike_output = lif.update(v, motion_mask)
# 保存结果
motion_results.append(spike_output.cpu().numpy())