# motionmask / utils_spike.py
# zzzzzeee's picture
# Update utils_spike.py
# 672f88e verified
# raw
# history blame
# 8.02 kB
import os, time
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
def load_vidar_dat(filename, frame_cnt=None, width=640, height=480, reverse_spike=True):
    """Decode a bit-packed spike recording (``.dat``) into a dense array.

    Each byte of the file stores eight consecutive pixels, least-significant
    bit first, and one frame occupies ``height * width // 8`` bytes.

    Args:
        filename: Path of the file to read.
        frame_cnt: Number of frames to decode. ``None`` decodes every complete
            frame contained in the file.
        width: Frame width in pixels.
        height: Frame height in pixels.
        reverse_spike: When true, every frame is flipped vertically.

    Returns:
        ``numpy.ndarray`` of dtype float32 and shape
        ``(frame_cnt, height, width)`` whose values are 0 or 1.

    Raises:
        ValueError: If ``height * width`` is not a positive multiple of 8, or
            if ``frame_cnt`` is negative or exceeds the number of complete
            frames stored in the file.
    """
    pixels_per_frame = height * width
    if pixels_per_frame <= 0 or pixels_per_frame % 8 != 0:
        raise ValueError(
            f"height * width must be a positive multiple of 8, got {height} * {width}"
        )

    raw = np.fromfile(filename, dtype=np.uint8)
    len_per_frame = pixels_per_frame // 8
    available = len(raw) // len_per_frame
    framecnt = available if frame_cnt is None else frame_cnt
    if framecnt < 0 or framecnt > available:
        raise ValueError(
            f"requested {framecnt} frames but {filename!r} holds {available} complete frames"
        )

    # bitorder='little' expands each byte LSB-first, which is the pixel order
    # of the packed format; one call replaces the per-frame / per-bit loops.
    bits = np.unpackbits(raw[:framecnt * len_per_frame], bitorder='little')
    frames = bits.reshape(framecnt, height, width)
    if reverse_spike:
        # Flip along the row axis of every frame (same as np.flipud per frame).
        frames = frames[:, ::-1, :]
    return np.ascontiguousarray(frames, dtype=np.float32)
class STPFilter:
    """Per-pixel short-term-plasticity (STP) spike filter with a LIF detect layer.

    The object keeps all of its state as ``(spike_h, spike_w)`` tensors on
    ``device``:

    * ``R`` / ``u`` - STP variables, updated where a spike arrives.
    * ``r_old`` - the last ``diff_time`` snapshots of ``R``.
    * ``filter_spk`` - spikes that passed the STP filter in the latest step.
    * ``detectVoltage`` / ``lif_spk`` - membrane voltage and output spikes of
      the convolutional LIF detect layer.
    """

    def __init__(self, spike_h, spike_w, device, diff_time=1, **STPargs):
        """Create the filter state.

        Args:
            spike_h: Frame height in pixels.
            spike_w: Frame width in pixels.
            device: Torch device (or device string) that holds all state.
            diff_time: Number of past ``R`` snapshots kept in ``r_old``.
            **STPargs: Optional overrides. Recognised keys and their defaults:
                ``u0`` (0.1), ``D`` (0.02), ``F`` (1.7), ``f`` (0.11),
                ``time_unit`` (2000), ``lifSize`` (3), ``filterThr`` (0.1),
                ``voltageMin`` (-8), ``lifThr`` (2). Every key is resolved
                independently; a missing key or an explicit ``None`` selects
                the default, so a partial override never leaves a parameter
                set to ``None``.
        """
        def _arg(key, default):
            value = STPargs.get(key)
            return default if value is None else value

        self.spike_h = spike_h
        self.spike_w = spike_w
        self.device = device

        # STP parameters.
        self.u0 = _arg('u0', 0.1)
        self.D = _arg('D', 0.02)
        self.F = _arg('F', 1.7)
        self.f = _arg('f', 0.11)
        self.time_unit = _arg('time_unit', 2000)
        self.r0 = 1

        # Length of the window of past R values used to compute the difference.
        self.diff_time = diff_time

        self.R = torch.ones(self.spike_h, self.spike_w, device=self.device) * self.r0
        self.u = torch.ones(self.spike_h, self.spike_w, device=self.device) * self.u0
        self.r_old = torch.ones(
            self.diff_time, self.spike_h, self.spike_w, device=self.device) * self.r0

        # LIF detect layer parameters.
        self.detectVoltage = torch.zeros(self.spike_h, self.spike_w, device=self.device)
        lifSize = _arg('lifSize', 3)
        # NOTE(review): only an odd lifSize gives a conv output of the same
        # size as the input, which local_connect* relies on.
        paddingSize = (lifSize - 1) // 2
        self.lifConv = torch.nn.Conv2d(in_channels=1, out_channels=1,
                                       kernel_size=(lifSize, lifSize),
                                       padding=(paddingSize, paddingSize),
                                       bias=False)
        # Fixed (non-learned) kernel: every tap has weight 3.
        self.lifConv.weight.data = torch.full((1, 1, lifSize, lifSize), 3.0)
        self.lifConv = self.lifConv.to(self.device)

        self.filterThr = _arg('filterThr', 0.1)  # base filter threshold
        self.voltageMin = _arg('voltageMin', -8)
        self.lifThr = _arg('lifThr', 2)

        self.filter_spk = torch.zeros(self.spike_h, self.spike_w, device=self.device)
        self.lif_spk = torch.zeros(self.spike_h, self.spike_w, device=self.device)
        # Value written at each pixel's most recent spike (curT + 1).
        self.spikePrevMnt = torch.zeros([self.spike_h, self.spike_w], device=self.device)
        self.stp_gradient = torch.zeros(self.spike_h, self.spike_w, device=self.device)
        self.adjusted_threshold = torch.zeros(self.spike_h, self.spike_w, device=self.device)

    def update_dynamics(self, curT, spikes):
        """Advance the STP state by one frame and refresh ``filter_spk``.

        Args:
            curT: Index of the current frame; ``curT + 1`` is stored as the
                spike time of the pixels that fire.
            spikes: Tensor of shape ``(spike_h, spike_w)``; non-zero entries
                are treated as spikes.
        """
        spike_bool = spikes.bool()

        spikeCurMnt = self.spikePrevMnt.clone()
        spikeCurMnt[spike_bool] = curT + 1
        # Inter-spike interval of the firing pixels, scaled by time_unit.
        dt = (spikeCurMnt[spike_bool] - self.spikePrevMnt[spike_bool]) / self.time_unit

        # Both updates use the value of u from before this step.
        r_prev = self.R[spike_bool]
        u_prev = self.u[spike_bool]
        self.R[spike_bool] = 1 - (1 - r_prev * (1 - u_prev)) * torch.exp(-dt / self.D)
        self.u[spike_bool] = self.u0 + (
            u_prev + self.f * (1 - u_prev) - self.u0) * torch.exp(-dt / self.F)

        tmp_diff = torch.abs(self.R - self.r_old[0])

        # Dynamically adjust the filter threshold based on the gradient:
        # a running average of |dR| / R lowers the threshold where R changes.
        self.stp_gradient = 0.5 * self.stp_gradient + 0.5 * torch.div(tmp_diff, self.R)
        gradient_sqrt = torch.sqrt(self.stp_gradient) + 1
        self.adjusted_threshold = torch.div(self.filterThr, gradient_sqrt)

        self.filter_spk[:] = 0
        self.filter_spk[spike_bool & (tmp_diff >= self.adjusted_threshold)] = 1

        if curT < self.diff_time:
            self.r_old[curT] = self.R
        else:
            # Shift the history by one slot; clone() because source and
            # destination slices overlap.
            self.r_old[:-1] = self.r_old[1:].clone()
            self.r_old[-1] = self.R

        self.spikePrevMnt = spikeCurMnt

    def update_dynamic_offline(self, spikes, intervals):
        """Run the STP update over a whole sequence and record ``R`` and ``u``.

        Args:
            spikes: Array indexable as ``spikes[t, :, :]`` whose entries equal
                to 1 mark spikes (assumed to be a NumPy array - TODO confirm).
            intervals: Array of shape ``(T, spike_h, spike_w)`` holding the
                inter-spike interval of every pixel at every step.

        Returns:
            ``(R, u)``: two CPU tensors of shape ``(T, spike_h, spike_w)``.
            Index 0 keeps the initial values ``r0`` / ``u0``; index ``t``
            holds the state after step ``t``.
        """
        isi_num = intervals.shape[0]
        R = torch.ones(isi_num, self.spike_h, self.spike_w) * self.r0
        u = torch.ones(isi_num, self.spike_h, self.spike_w) * self.u0

        # NOTE(review): prev_isi is never advanced inside the loop, so every
        # step is compared against intervals[0] - confirm this is intended.
        prev_isi = intervals[0, :, :]

        for t in range(1, isi_num):
            tmp_isi = intervals[t, :, :]
            changed_and_fired = (tmp_isi != prev_isi) & (spikes[t, :, :] == 1)
            update_idx = torch.as_tensor(
                changed_and_fired | (tmp_isi == 1), dtype=torch.bool, device=self.device)

            # NOTE(review): unlike update_dynamics, the interval is not
            # divided by time_unit here - confirm the expected units.
            isi = torch.as_tensor(tmp_isi, device=self.device).float()[update_idx]

            r_prev = self.R[update_idx]
            u_prev = self.u[update_idx]
            self.R[update_idx] = 1 - (1 - r_prev * (1 - u_prev)) * torch.exp(-isi / self.D)
            self.u[update_idx] = self.u0 + (
                u_prev + self.f * (1 - u_prev) - self.u0) * torch.exp(-isi / self.F)

            # Slice assignment copies the values, so later in-place updates of
            # self.R / self.u cannot change the recorded history.
            R[t, :, :] = self.R
            u[t, :, :] = self.u

        return R, u

    def _lif_step(self, spikes):
        """Apply one LIF detect-layer step for a single frame of spikes."""
        inputSpk = torch.reshape(spikes, (1, 1, self.spike_h, self.spike_w)).float()

        # Pixels without a spike leak one unit of voltage.
        self.detectVoltage[inputSpk[0, 0] == 0] -= 1

        # The kernel is fixed, so no autograd graph is needed.
        with torch.no_grad():
            drive = self.lifConv(inputSpk)[0, 0]
        self.detectVoltage += drive
        self.detectVoltage.clamp_(min=self.voltageMin)

        fired = self.detectVoltage >= self.lifThr
        self.lif_spk[:] = 0
        self.lif_spk[fired] = 1
        # Soft reset: firing pixels keep 80 % of their voltage.
        self.detectVoltage[fired] *= 0.8

    def local_connect(self, spikes):
        """Update ``detectVoltage`` and ``lif_spk`` from one frame of spikes.

        Args:
            spikes: Tensor that reshapes to ``(1, 1, spike_h, spike_w)``;
                zero entries are treated as "no spike".
        """
        self._lif_step(spikes)

    def local_connect_offline(self, spikes):
        """Run the LIF detect layer over a sequence of frames.

        Args:
            spikes: Array whose first axis indexes frames; each frame must
                reshape to ``(spike_h, spike_w)``.

        Returns:
            ``(voltages, lif_spikes)``: two lists with one NumPy array per
            frame, holding ``detectVoltage`` and ``lif_spk`` after that frame.
        """
        tmp_voltage = []
        lif_spk = []
        for iSpk in range(spikes.shape[0]):
            self._lif_step(torch.as_tensor(spikes[iSpk], device=self.device))
            # .copy() is required: on CPU, .numpy() shares memory with the
            # state tensors, which are overwritten by the next frame.
            tmp_voltage.append(self.detectVoltage.detach().cpu().numpy().copy())
            lif_spk.append(self.lif_spk.detach().cpu().numpy().copy())
        return tmp_voltage, lif_spk