# rightnow / utils_spike.py
# Author: zzzzzeee — commit 5fc3d65 ("Upload 101 files", verified)
import os, time
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from recon.recon_bsf.bsf.bsf import BSF
from recon.recon_bsf.bsf.dsft_convert import convert_dsft4
from recon.recon_bsf.bsf_utils import InputPadder, spike2dsftpre4
from collections import OrderedDict
def load_vidar_dat(filename, frame_cnt=None, width=640, height=480, reverse_spike=True):
    '''Load a bit-packed spike stream file into a dense per-frame array.

    Each frame occupies ``height * width // 8`` bytes. Bit ``b`` of byte ``k``
    (least-significant bit first) is pixel ``8 * k + b`` of the frame in
    row-major (C) order.

    Args:
        filename: path of the binary file to read.
        frame_cnt: number of frames to decode; ``None`` decodes every complete
            frame in the file (a trailing partial frame is ignored).
        width, height: frame size in pixels.
        reverse_spike: if True, flip every frame upside-down (row order reversed).

    Returns:
        <class 'numpy.ndarray'> (frame_cnt, height, width) {0,1} float32.
        A request for zero frames returns shape (0, height, width).

    Raises:
        ValueError: if the file holds fewer than ``frame_cnt`` complete frames,
            or if ``height * width`` is not a multiple of 8 (reshape fails).
    '''
    array = np.fromfile(filename, dtype=np.uint8)
    len_per_frame = height * width // 8
    framecnt = frame_cnt if frame_cnt is not None else len(array) // len_per_frame
    if framecnt <= 0:
        return np.zeros((0, height, width), dtype=np.float32)

    needed = framecnt * len_per_frame
    if needed > len(array):
        raise ValueError(
            f"{filename} holds {len(array) // len_per_frame} complete frames, "
            f"but {framecnt} were requested"
        )

    # Unpack every frame in one vectorised call; bitorder='little' places
    # bit b of byte k at flat pixel position 8*k + b.
    bits = np.unpackbits(array[:needed], bitorder='little')
    spikes = bits.reshape((framecnt, height, width))
    if reverse_spike:
        spikes = spikes[:, ::-1, :]  # same as np.flipud applied to each frame
    return spikes.astype(np.float32)
def spikes_to_middletfi(spike, middle, window=50):
    """Per-pixel inverse inter-spike interval around frame ``middle``.

    For every pixel the nearest spike is searched on each side of ``middle``:
    the left search covers frames ``middle, middle-1, ..., middle-window+1``
    (stopping at 0) and the right search covers ``middle+1, ..., middle+window``
    (stopping at the last frame). Pixels with no spike on a side fall back to
    ``middle - window`` (left) or ``middle + window`` (right). The result is
    ``1 / (right_index - left_index)``.

    Args:
        spike: tensor of shape (C, H, W) with C frames; any non-zero entry is
            treated as a spike (presumably binary {0, 1} — confirm with callers).
        middle: index of the centre frame (must be < C).
        window: maximum number of frames searched on each side. With
            ``window == 0`` nothing is searched and the interval is 0 (inf).

    Returns:
        float32 tensor of shape (1, H, W), on the same device as ``spike``.
    """
    C, H, W = spike.shape
    device = spike.device
    # Pre-fill with the "not found" fallbacks and track hits with explicit
    # masks: frame 0 is a legitimate spike index, so it cannot double as a
    # "not found yet" sentinel.
    lindex = torch.full((H, W), float(middle - window), device=device)
    rindex = torch.full((H, W), float(middle + window), device=device)
    lfound = torch.zeros((H, W), dtype=torch.bool, device=device)
    rfound = torch.zeros((H, W), dtype=torch.bool, device=device)

    for offset in range(window):
        l = middle - offset
        r = middle + 1 + offset
        if l < 0 and r >= C:
            break  # both searches have left the stream
        if l >= 0:
            hit = (spike[l] != 0) & ~lfound  # first spike seen going left
            lindex[hit] = l
            lfound |= hit
        if r < C:
            hit = (spike[r] != 0) & ~rfound  # first spike seen going right
            rindex[hit] = r
            rfound |= hit

    interval = rindex - lindex
    tfi = 1.0 / interval
    return tfi.unsqueeze(0).float()
def spikes_to_tfp(spike, idx, halfwsize):
    """Min-max normalised temporal average of spikes in a window centred on ``idx``.

    The window covers frames ``idx - halfwsize`` .. ``idx + halfwsize``
    inclusive, i.e. ``2 * halfwsize + 1`` frames. It is clipped at frame 0 so
    that a window near the start does not wrap around to the end of the stream.

    Args:
        spike: float tensor whose first axis is time, e.g. (C, H, W).
        idx: centre frame index.
        halfwsize: number of frames taken on each side of ``idx``.

    Returns:
        Tensor of shape ``spike.shape[1:]`` scaled to [0, 1]. If the averaged
        image is constant (max == min), returns zeros instead of NaN.
    """
    start = max(idx - halfwsize, 0)  # avoid negative-index wrap-around
    window = spike[start:idx + halfwsize + 1]
    tfp_img = torch.mean(window, dim=0)
    spike_min, spike_max = torch.min(tfp_img), torch.max(tfp_img)
    value_range = spike_max - spike_min
    if value_range == 0:
        # Constant image: normalising would compute 0/0.
        return torch.zeros_like(tfp_img)
    return (tfp_img - spike_min) / value_range
def spikes_to_bsf(spike, model, padder, central_idx, device, half_window=30):
    """Reconstruct the frame at ``central_idx`` from a spike stream with a BSF model.

    Args:
        spike: spike stream handed to ``spike2dsftpre4``; ``len(spike)`` is
            used as the number of frames.
        model: BSF network. It is switched to eval mode in place and called as
            ``model(input_dict={'dsft_dict': ..., 'spikes': ...})``.
        padder: object exposing ``pad(spike, dsfts)`` and ``unpad(rec)``
            (presumably an ``InputPadder`` — confirm with callers).
        central_idx: index of the frame to reconstruct.
        device: torch device the model inputs are moved to.
        half_window: ``half_window`` argument forwarded to ``spike2dsftpre4``
            (defaults to 30).

    Returns:
        CPU tensor ``rec[0]`` (the first batch element of the unpadded model
        output), clamped to [0, 1].
    """
    model.eval()
    # Half the distance from central_idx to the nearer end of the stream.
    max_search_half_window = min((len(spike) - central_idx), central_idx) // 2
    spike, dsfts = spike2dsftpre4(
        spike,
        central_idx=central_idx,
        half_window=half_window,
        max_search_half_window=max_search_half_window,
    )
    # spike2dsftpre4 returns NumPy arrays; add a leading batch dimension of 1.
    spike = torch.from_numpy(spike).float().to(device).unsqueeze(0)
    dsfts = torch.from_numpy(dsfts).float().to(device).unsqueeze(0)
    spike, dsfts = padder.pad(spike, dsfts)
    # dsft-m
    dsft_dict = convert_dsft4(spike=spike, dsft=dsfts)
    input_dict = {
        'dsft_dict': dsft_dict,
        'spikes': spike,
    }
    with torch.no_grad():
        rec = model(input_dict=input_dict)
    rec = padder.unpad(rec)
    # NOTE(review): clamping to [0, 1] assumes the model predicts normalised
    # intensities — confirm against the model's training target.
    rec = torch.clip(rec[0], 0, 1)
    return rec.detach().cpu()