File size: 7,614 Bytes
9fa5305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
# -*- coding: utf-8 -*- 
# @Time : 2023/7/16 20:13 
# @Author : Yajing Zheng
# @Email: [email protected]
# @File : load_dat.py
import os, sys
import warnings
import glob
import yaml
import numpy as np
import path

# key-value for generate data loader according to the type of label data
LABEL_DATA_TYPE = {
    'raw': 0,
    'reconstruction': 1,
    'optical_flow': 2,
    'mono_depth_estimation': 3.1,
    'stero_depth_estimation': 3.2,
    'detection': 4,
    'tracking': 5,
    'recognition': 6
}


# generate parameters dictionary according to labeled or not
def data_parameter_dict(data_filename, label_type):
    filename = path.split_path_into_pieces(data_filename)

    if os.path.isabs(data_filename):
        file_root = data_filename
        if os.path.isdir(file_root):
            search_root = file_root
        else:
            search_root = '\\'.join(filename[0:-1])
        config_filename = path.seek_file(search_root, 'config.yaml')
    else:
        file_root = os.path.join('', 'datasets', *filename)
        config_filename = os.path.join('', 'datasets', filename[0], 'config.yaml')

    try:
        with open(config_filename, 'r', encoding='utf-8') as fin:
            configs = yaml.load(fin, Loader=yaml.FullLoader)
    except TypeError as err:
        print("Cannot find config file" + str(err))
        raise err

    except KeyError as exception:
        print('ERROR! Task name does not exist')
        print('Task name must be in %s' % LABEL_DATA_TYPE.keys())
        raise exception

    is_labeled = configs.get('is_labeled')

    paraDict = {'spike_h': configs.get('spike_h'), 'spike_w': configs.get('spike_w')}
    paraDict['filelist'] = None

    if is_labeled:
        paraDict['labeled_data_type'] = configs.get('labeled_data_type')
        paraDict['labeled_data_suffix'] = configs.get('labeled_data_suffix')
        paraDict['label_root_list'] = None

        if os.path.isdir(file_root):
            filelist = sorted(glob.glob(file_root + '/*.dat'), key=os.path.getmtime)
            filepath = filelist[0]

            labelname = path.replace_identifier(filename, configs.get('data_field_identifier', ''),
                                                configs.get('label_field_identifier', ''))
            label_root_list = os.path.join('', 'datasets', *labelname)
            paraDict['labeled_data_dir'] = sorted(glob.glob(label_root_list + '/*.' + paraDict['labeled_data_suffix']),
                                                  key=os.path.getmtime)

            paraDict['filelist'] = filelist
            paraDict['label_root_list'] = label_root_list
        else:
            filepath = glob.glob(file_root)[0]
            rawname = filename[-1].replace('.dat', '')
            filename.pop(-1)
            filename.append(rawname)
            labelname = path.replace_identifier(filename, configs.get('data_field_identifier', ''),
                                                configs.get('label_field_identifier', ''))
            label_root = os.path.join('', 'datasets', *labelname)
            paraDict['labeled_data_dir'] = glob.glob(label_root + '.' + paraDict['labeled_data_suffix'])[0]
    else:
        filepath = file_root

    paraDict['filepath'] = filepath

    return paraDict


class SpikeStream:
    def __init__(self, **kwargs):

        self.SpikeMatrix = None
        self.filename = kwargs.get('filepath')
        if os.path.splitext(self.filename)[-1][1:] != 'dat':
            self.filename = self.filename + '.dat'
        self.spike_w = kwargs.get('spike_w')
        self.spike_h = kwargs.get('spike_h')
        if 'print_dat_detail' not in kwargs:
            self.print_dat_detail = True
        else:
            self.print_dat_detail = kwargs.get('print_dat_detail')

    def get_spike_matrix(self, flipud=True, with_head=False):

        file_reader = open(self.filename, 'rb')
        video_seq = file_reader.read()
        video_seq = np.frombuffer(video_seq, 'b')

        video_seq = np.array(video_seq).astype(np.byte)
        if self.print_dat_detail:
            print(video_seq)
        if with_head:
            decode_width = 416
        else:
            decode_width = self.spike_w
        # img_size = self.spike_height * self.spike_width
        img_size = self.spike_h * decode_width
        img_num = len(video_seq) // (img_size // 8)

        if self.print_dat_detail:
            print('loading total spikes from dat file -- spatial resolution: %d x %d, total timestamp: %d' %
                  (decode_width, self.spike_h, img_num))

        # SpikeMatrix = np.zeros([img_num, self.spike_h, self.spike_width], np.byte)

        pix_id = np.arange(0, img_num * self.spike_h * decode_width)
        pix_id = np.reshape(pix_id, (img_num, self.spike_h, decode_width))
        comparator = np.left_shift(1, np.mod(pix_id, 8))
        byte_id = pix_id // 8

        data = video_seq[byte_id]
        result = np.bitwise_and(data, comparator)
        tmp_matrix = (result == comparator)

        # if with head, delete them
        if with_head:
            delete_indx = np.arange(400, 416)
            tmp_matrix = np.delete(tmp_matrix, delete_indx, 2)

        if flipud:
            self.SpikeMatrix = tmp_matrix[:, ::-1, :]
        else:
            self.SpikeMatrix = tmp_matrix

        file_reader.close()
        self.SpikeMatrix = self.SpikeMatrix.astype(np.byte)
        return self.SpikeMatrix

        # return spikes with specified length and begin index

    def get_block_spikes(self, begin_idx, block_len=1, flipud=True, with_head=False):

        file_reader = open(self.filename, 'rb')
        video_seq = file_reader.read()
        video_seq = np.frombuffer(video_seq, 'b')

        video_seq = np.array(video_seq).astype(np.uint8)

        if with_head:
            decode_width = 416
        else:
            decode_width = self.spike_w
        # img_size = self.spike_height * self.spike_width
        img_size = self.spike_h * decode_width
        img_num = len(video_seq) // (img_size // 8)

        end_idx = begin_idx + block_len
        if end_idx > img_num:
            warnings.warn("block_len exceeding upper limit! Zeros will be padded in the end. ", ResourceWarning)
            end_idx = img_num

        if self.print_dat_detail:
            print(
                'loading total spikes from dat file -- spatial resolution: %d x %d, begin index: %d total timestamp: %d' %
                (decode_width, self.spike_h, begin_idx, block_len))

        pix_id = np.arange(0, block_len * self.spike_h * decode_width)
        pix_id = np.reshape(pix_id, (block_len, self.spike_h, decode_width))
        comparator = np.left_shift(1, np.mod(pix_id, 8))
        byte_id = pix_id // 8
        id_start = begin_idx * img_size // 8
        id_end = id_start + block_len * img_size // 8
        data = video_seq[id_start:id_end]
        data_frame = data[byte_id]
        result = np.bitwise_and(data_frame, comparator)
        tmp_matrix = (result == comparator)

        # if with head, delete them
        if with_head:
            delete_indx = np.arange(400, 416)
            tmp_matrix = np.delete(tmp_matrix, delete_indx, 2)

        if flipud:
            self.SpikeMatrix = tmp_matrix[:, ::-1, :]
        else:
            self.SpikeMatrix = tmp_matrix

        file_reader.close()
        self.SpikeMatrix = self.SpikeMatrix.astype(np.byte)
        return self.SpikeMatrix