import os
import time
from collections import OrderedDict

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

from recon.recon_bsf.bsf.bsf import BSF
from recon.recon_bsf.bsf.dsft_convert import convert_dsft4
from recon.recon_bsf.bsf_utils import InputPadder, spike2dsftpre4


def load_vidar_dat(filename, frame_cnt=None, width=640, height=480, reverse_spike=True):
    '''
    Load a Vidar spike stream from a .dat file in which every byte packs
    8 binary spikes (least-significant bit first).

    output: <class 'numpy.ndarray'> (frame_cnt, height, width) {0,1} float32
    '''
    array = np.fromfile(filename, dtype=np.uint8)

    len_per_frame = height * width // 8
    framecnt = frame_cnt if frame_cnt is not None else len(array) // len_per_frame

    spikes = []
    for i in range(framecnt):
        compr_frame = array[i * len_per_frame: (i + 1) * len_per_frame]
        blist = []
        for b in range(8):
            # Extract bit b of every byte in the frame.
            blist.append(np.right_shift(np.bitwise_and(
                compr_frame, np.left_shift(1, b)), b))

        frame_ = np.stack(blist).transpose()
        frame_ = frame_.reshape((height, width), order='C')
        if reverse_spike:
            frame_ = np.flipud(frame_)
        spikes.append(frame_)

    return np.array(spikes).astype(np.float32)


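# An alternative, vectorized loader (sketch, not used elsewhere in this file):
# the per-bit loop above can be replaced by NumPy's little-endian unpackbits.
# Assumes NumPy >= 1.17 for the `bitorder` argument; the function name is ours.
def load_vidar_dat_unpackbits(filename, frame_cnt=None, width=640, height=480, reverse_spike=True):
    array = np.fromfile(filename, dtype=np.uint8)
    len_per_frame = height * width // 8
    framecnt = frame_cnt if frame_cnt is not None else len(array) // len_per_frame
    # LSB-first unpacking matches the bit order used by load_vidar_dat.
    bits = np.unpackbits(array[:framecnt * len_per_frame], bitorder='little')
    spikes = bits.reshape(framecnt, height, width).astype(np.float32)
    if reverse_spike:
        spikes = spikes[:, ::-1, :].copy()  # per-frame vertical flip, as np.flipud
    return spikes

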
def spikes_to_middletfi(spike, middle, window=50):
    '''
    Texture From Interval (TFI) reconstruction at frame `middle`: for each pixel,
    intensity is the reciprocal of the interval between the nearest spikes on
    either side of the middle frame, searched within `window` frames per side.
    '''
    C, H, W = spike.shape
    # Per-pixel indices of the nearest spike to the left / right of `middle`.
    lindex, rindex = torch.zeros([H, W]), torch.zeros([H, W])
    l, r = middle + 1, middle + 1
    for r in range(middle + 1, middle + window + 1):
        l = l - 1
        if l >= 0:
            # Record l only for pixels whose left spike has not been found yet.
            newpos = spike[l, :, :] * (1 - torch.sign(lindex))
            distance = l * newpos
            lindex += distance
        if r < C:
            newpos = spike[r, :, :] * (1 - torch.sign(rindex))
            distance = r * newpos
            rindex += distance
        if l < 0 and r >= C:
            break
    # Pixels with no spike found fall back to the window boundary.
    rindex[rindex == 0] = window + middle
    lindex[lindex == 0] = middle - window
    interval = rindex - lindex
    tfi = 1.0 / interval
    tfi = tfi.unsqueeze(0)
    return tfi.float()


def spikes_to_tfp(spike, idx, halfwsize):
    '''
    Texture From Playback (TFP) reconstruction at frame `idx`: average the spikes
    in a symmetric window of 2 * halfwsize frames and normalize to [0, 1].
    '''
    spike_ = spike[idx - halfwsize:idx + halfwsize]
    tfp_img = torch.mean(spike_, dim=0)
    # Min-max normalize to [0, 1].
    spike_min, spike_max = torch.min(tfp_img), torch.max(tfp_img)
    tfp_img = (tfp_img - spike_min) / (spike_max - spike_min)
    return tfp_img


def spikes_to_bsf(spike, model, padder, central_idx, device):
    '''
    Reconstruct the frame at `central_idx` with a BSF model. `spike` is the raw
    spike array (numpy) and `padder` an InputPadder matching its spatial size.
    '''
    model.eval()
    max_search_half_window = min(len(spike) - central_idx, central_idx) // 2

    # Build the spike window and its DSFT representation around central_idx.
    spike, dsfts = spike2dsftpre4(spike, central_idx=central_idx, half_window=30,
                                  max_search_half_window=max_search_half_window)
    spike = torch.from_numpy(spike).float().to(device).unsqueeze(0)
    dsfts = torch.from_numpy(dsfts).float().to(device).unsqueeze(0)

    # Pad to a spatial size the network accepts.
    spike, dsfts = padder.pad(spike, dsfts)

    dsft_dict = convert_dsft4(spike=spike, dsft=dsfts)
    input_dict = {
        'dsft_dict': dsft_dict,
        'spikes': spike,
    }
    with torch.no_grad():
        rec = model(input_dict=input_dict)

    rec = padder.unpad(rec)
    rec = torch.clip(rec[0], 0, 1)
    rec = rec.detach().cpu()
    return rec
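

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only). The .dat path, resolution, the BSF
# checkpoint path, and the BSF()/InputPadder construction are assumptions, not
# values shipped with this repo; adapt them to your data and to the actual
# recon_bsf API before running.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Placeholder spike stream and resolution.
    spikes_np = load_vidar_dat('data/example.dat', width=400, height=250)
    spikes = torch.from_numpy(spikes_np)
    mid = spikes_np.shape[0] // 2

    # Model-free reconstructions.
    tfp = spikes_to_tfp(spikes, idx=mid, halfwsize=20)
    tfi = spikes_to_middletfi(spikes, middle=mid, window=50)
    cv2.imwrite('tfp.png', (tfp.numpy() * 255).astype(np.uint8))
    cv2.imwrite('tfi.png', (tfi[0] / tfi.max() * 255).numpy().astype(np.uint8))

    # BSF reconstruction; the constructor arguments, checkpoint layout, and
    # InputPadder signature below are guesses -- check the recon_bsf code.
    # model = BSF().to(device)
    # model.load_state_dict(torch.load('checkpoints/bsf.pth', map_location=device))
    # padder = InputPadder(spikes_np.shape)
    # rec = spikes_to_bsf(spikes_np, model, padder, central_idx=mid, device=device)
    # cv2.imwrite('bsf.png', (rec.squeeze().numpy() * 255).astype(np.uint8))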