import os
import time
from collections import OrderedDict

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

from recon.recon_bsf.bsf.bsf import BSF
from recon.recon_bsf.bsf.dsft_convert import convert_dsft4
from recon.recon_bsf.bsf_utils import InputPadder, spike2dsftpre4


def load_vidar_dat(filename, frame_cnt=None, width=640, height=480, reverse_spike=True):
    """Load a bit-packed spike stream from a ``.dat`` file.

    Each frame occupies ``height * width // 8`` bytes.  Bit ``b`` (least
    significant bit first) of byte ``j`` in a frame is the spike value of the
    row-major flattened pixel ``8 * j + b``.

    Args:
        filename: path of the binary file.
        frame_cnt: number of frames to decode; ``None`` decodes every complete
            frame stored in the file.
        width, height: frame size in pixels.
        reverse_spike: if True, every frame is flipped vertically
            (row order reversed, like ``np.flipud``).

    Returns:
        float32 array of shape (frame_cnt, height, width) with values in {0, 1}.

    Raises:
        ValueError: if ``frame_cnt`` is negative or exceeds the number of
            complete frames in the file.
    """
    raw = np.fromfile(filename, dtype=np.uint8)
    len_per_frame = height * width // 8
    available = len(raw) // len_per_frame
    framecnt = frame_cnt if frame_cnt is not None else available
    if not 0 <= framecnt <= available:
        raise ValueError(
            f"{filename}: requested {framecnt} frames but only "
            f"{available} complete frames are stored"
        )
    # Unpack all frames at once; 'little' bit order reproduces the LSB-first
    # layout described above.
    bits = np.unpackbits(raw[:framecnt * len_per_frame], bitorder='little')
    spikes = bits.reshape(framecnt, height, width)
    if reverse_spike:
        spikes = spikes[:, ::-1, :]  # per-frame vertical flip
    return spikes.astype(np.float32)


def spikes_to_middletfi(spike, middle, window=50):
    """Per-pixel inter-spike-interval image around frame ``middle``.

    For every pixel this finds:
    - the nearest spike at or before ``middle``, searching frames
      ``middle, middle-1, ..., middle-window+1``;
    - the nearest spike after ``middle``, searching frames
      ``middle+1, ..., middle+window``, stopping at the end of the stream.

    It returns the reciprocal of the distance between the two spikes.

    When no spike is found, the left fallback is ``middle - window`` and the
    right fallback is ``window + middle``.

    Args:
        spike: tensor of shape (C, H, W); C is the number of frames.
        middle: index of the reference frame.
        window: maximum number of frames searched on each side.

    Returns:
        float tensor of shape (1, H, W).
    """
    C, H, W = spike.shape
    # Allocate on the input's device so GPU tensors work too.
    lindex = torch.zeros([H, W], device=spike.device)
    rindex = torch.zeros([H, W], device=spike.device)
    l = middle + 1
    for r in range(middle + 1, middle + window + 1):
        l = l - 1
        if l >= 0:
            # (1 - sign(index)) is 1 only where no nearer spike was recorded
            # yet, so each pixel keeps its first (closest) hit.
            newpos = spike[l, :, :] * (1 - torch.sign(lindex))
            lindex += l * newpos
        if r < C:
            newpos = spike[r, :, :] * (1 - torch.sign(rindex))
            rindex += r * newpos
        if r >= C:
            break
    # NOTE(review): a spike located at frame 0 is recorded as index 0 and is
    # therefore indistinguishable from "not found" here; it receives the
    # fallback value below.
    rindex[rindex == 0] = window + middle
    lindex[lindex == 0] = middle - window
    interval = rindex - lindex
    tfi = 1.0 / interval
    tfi = tfi.unsqueeze(0)
    return tfi.float()


def spikes_to_tfp(spike, idx, halfwsize):
    """Min-max normalised mean spike rate over a window centred on ``idx``.

    The window is ``spike[idx - halfwsize : idx + halfwsize + 1]``, i.e.
    ``2 * halfwsize + 1`` frames, clipped at the start of the stream. Negative
    start indices are clamped to 0 rather than wrapping around.

    Args:
        spike: tensor whose first axis indexes frames.
        idx: centre frame index.
        halfwsize: half size of the window.

    Returns:
        The per-pixel mean over the window rescaled to [0, 1]; an all-zero
        image is returned when the mean is constant (avoids 0/0 = NaN).
    """
    start = max(idx - halfwsize, 0)
    stop = idx + halfwsize + 1
    tfp_img = torch.mean(spike[start:stop], dim=0)
    spike_min, spike_max = torch.min(tfp_img), torch.max(tfp_img)
    span = spike_max - spike_min
    if span > 0:
        return (tfp_img - spike_min) / span
    return torch.zeros_like(tfp_img)


def spikes_to_bsf(spike, model, padder, central_idx, device, half_window=30):
    """Reconstruct the frame at ``central_idx`` with a BSF model.

    Args:
        spike: spike stream passed to ``spike2dsftpre4``; must support
            ``len()`` (presumably a numpy array of frames — confirm with caller).
        model: reconstruction network; this call switches it to eval mode.
        padder: object with ``pad(spike, dsfts)`` / ``unpad(rec)`` methods
            (presumably an ``InputPadder``).
        central_idx: index of the frame to reconstruct.
        device: torch device used for inference.
        half_window: ``half_window`` forwarded to ``spike2dsftpre4``
            (default 30, the previously hard-coded value).

    Returns:
        CPU tensor ``rec[0]`` of the model output, clipped to [0, 1].
    """
    model.eval()
    # Search radius limited by the distance to the nearer end of the stream.
    max_search_half_window = min(len(spike) - central_idx, central_idx) // 2
    spike, dsfts = spike2dsftpre4(
        spike,
        central_idx=central_idx,
        half_window=half_window,
        max_search_half_window=max_search_half_window,
    )
    spike = torch.from_numpy(spike).float().to(device).unsqueeze(0)
    dsfts = torch.from_numpy(dsfts).float().to(device).unsqueeze(0)

    spike, dsfts = padder.pad(spike, dsfts)
    # dsft-m
    dsft_dict = convert_dsft4(spike=spike, dsft=dsfts)
    input_dict = {
        'dsft_dict': dsft_dict,
        'spikes': spike,
    }
    with torch.no_grad():
        rec = model(input_dict=input_dict)
    rec = padder.unpad(rec)
    # TODO: confirm clipping to [0, 1] is the intended output range.
    rec = torch.clip(rec[0], 0, 1)
    return rec.detach().cpu()