File size: 3,261 Bytes
5fc3d65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os, time
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from recon.recon_bsf.bsf.bsf import BSF
from recon.recon_bsf.bsf.dsft_convert import convert_dsft4
from recon.recon_bsf.bsf_utils import InputPadder, spike2dsftpre4
from collections import OrderedDict






def load_vidar_dat(filename, frame_cnt=None, width=640, height=480, reverse_spike=True):
    '''Load a bit-packed spike file into a dense {0,1} array.

    Each frame occupies ``height * width // 8`` bytes. Bit ``b`` (least
    significant bit first) of byte ``k`` holds the spike for flat pixel
    index ``8 * k + b``, and pixels are laid out row-major as (height, width).

    Args:
        filename: path of the packed spike file.
        frame_cnt: number of frames to decode; ``None`` decodes every complete
            frame in the file (a trailing partial frame is ignored).
        width: frame width in pixels.
        height: frame height in pixels.
        reverse_spike: if True, flip every frame vertically (reverse row order).

    Returns:
        <class 'numpy.ndarray'> (frame_cnt, height, width) {0,1} float32

    Raises:
        ValueError: if ``frame_cnt`` is negative or exceeds the number of
            complete frames stored in the file.
    '''
    array = np.fromfile(filename, dtype=np.uint8)

    len_per_frame = height * width // 8
    available = len(array) // len_per_frame
    framecnt = available if frame_cnt is None else frame_cnt
    if framecnt < 0 or framecnt > available:
        raise ValueError(
            f"frame_cnt={frame_cnt} out of range: file holds {available} complete frames"
        )

    packed = array[:framecnt * len_per_frame].reshape(framecnt, len_per_frame)
    # bitorder='little' expands each byte LSB-first, i.e. bit b -> pixel 8*k + b.
    spikes = np.unpackbits(packed, axis=1, bitorder='little')
    spikes = spikes.reshape(framecnt, height, width)
    if reverse_spike:
        spikes = spikes[:, ::-1, :]
    # Copy into a fresh C-contiguous float32 array (also drops the negative stride).
    return np.ascontiguousarray(spikes, dtype=np.float32)


def spikes_to_middletfi(spike, middle, window=50):
    '''Compute a TFI image from the spike interval surrounding frame ``middle``.

    For every pixel, two spikes are located:
    - the nearest spike at or before ``middle``, searching frames ``middle``
      down to ``middle - window + 1``;
    - the nearest spike after ``middle``, searching frames ``middle + 1`` up
      to ``middle + window``.

    Both searches are clipped to ``[0, C)``. The output value is
    ``1 / (right - left)``. If no spike is found on a side, ``middle - window``
    (left) or ``middle + window`` (right) is used instead.

    Args:
        spike: tensor of shape (C, H, W); any nonzero entry counts as a spike.
        middle: reference frame index along dim 0.
        window: number of frames searched on each side.

    Returns:
        float32 tensor of shape (1, H, W) on ``spike``'s device.
    '''
    C, H, W = spike.shape
    device = spike.device
    # Pre-fill with the "not found" fallbacks. Explicit found-masks are used
    # instead of a 0 sentinel, so a genuine spike at frame 0 is not mistaken
    # for "missing".
    lindex = torch.full((H, W), float(middle - window), device=device)
    rindex = torch.full((H, W), float(middle + window), device=device)
    lfound = torch.zeros((H, W), dtype=torch.bool, device=device)
    rfound = torch.zeros_like(lfound)
    for offset in range(window):
        left = middle - offset
        right = middle + 1 + offset
        if left < 0 and right >= C:
            break  # both searches have run off the stream
        if left >= 0:
            hit = (spike[left] != 0) & ~lfound  # first (nearest) hit only
            lindex[hit] = left
            lfound |= hit
        if right < C:
            hit = (spike[right] != 0) & ~rfound
            rindex[hit] = right
            rfound |= hit
    tfi = 1.0 / (rindex - lindex)
    return tfi.unsqueeze(0).float()


def spikes_to_tfp(spike, idx, halfwsize):
    '''Average spikes in a window centred on ``idx`` and min-max normalise.

    The window spans frames ``idx - halfwsize`` .. ``idx + halfwsize``
    inclusive (``2 * halfwsize + 1`` frames). It is clipped at the start of
    the stream so that a small ``idx`` does not wrap to negative indices.

    Args:
        spike: tensor with frames along dim 0 (e.g. (T, H, W)).
        idx: centre frame index.
        halfwsize: number of frames taken on each side of ``idx``.

    Returns:
        Tensor of shape ``spike.shape[1:]`` scaled to [0, 1]. It is all zeros
        when the averaged image is constant.
    '''
    start = max(idx - halfwsize, 0)
    # The slice end is exclusive, so +1 makes the window 2*halfwsize+1 frames.
    window = spike[start:idx + halfwsize + 1]
    tfp_img = torch.mean(window, dim=0)
    spike_min, spike_max = torch.min(tfp_img), torch.max(tfp_img)
    span = spike_max - spike_min
    if span == 0:
        # Constant image: avoid 0/0 producing NaNs.
        return torch.zeros_like(tfp_img)
    return (tfp_img - spike_min) / span
    

def spikes_to_bsf(spike, model, padder, central_idx, device, half_window=30):
    '''Reconstruct the frame at ``central_idx`` with a BSF model.

    Side effect: calls ``model.eval()``.

    Args:
        spike: spike stream with frames along dim 0 (``len(spike)`` frames).
            It is passed to ``spike2dsftpre4``, which presumably expects a
            numpy array such as the output of ``load_vidar_dat`` (TODO confirm).
        model: network called as ``model(input_dict=...)``.
        padder: object providing ``pad(spike, dsfts)`` and ``unpad(rec)``.
        central_idx: index of the frame to reconstruct.
        device: torch device the network inputs are moved to.
        half_window: forwarded to ``spike2dsftpre4`` as ``half_window``.
            Defaults to 30, the previously hard-coded value.

    Returns:
        CPU tensor ``rec[0]`` (first batch element of the unpadded model
        output), clipped to [0, 1].
    '''
    model.eval()
    # Half of the smaller distance from central_idx to either end of the stream.
    max_search_half_window = min(len(spike) - central_idx, central_idx) // 2
    spike, dsfts = spike2dsftpre4(
        spike,
        central_idx=central_idx,
        half_window=half_window,
        max_search_half_window=max_search_half_window,
    )
    # Add a leading batch dimension and move to the target device.
    spike = torch.from_numpy(spike).float().to(device).unsqueeze(0)
    dsfts = torch.from_numpy(dsfts).float().to(device).unsqueeze(0)
    spike, dsfts = padder.pad(spike, dsfts)
    # dsft-m
    dsft_dict = convert_dsft4(spike=spike, dsft=dsfts)
    input_dict = {
        'dsft_dict': dsft_dict,
        'spikes': spike,
    }
    with torch.no_grad():
        rec = model(input_dict=input_dict)
    rec = padder.unpad(rec)
    # NOTE(review): the original marked this clip with a bare TODO; the reason is not stated.
    rec = torch.clip(rec[0], 0, 1)
    return rec.detach().cpu()