# -*- coding: utf-8 -*-
# @Time : 2023/7/16 20:13
# @Author : Yajing Zheng
# @Email: [email protected]
# @File : load_dat.py
import os
import warnings
import glob

import yaml
import numpy as np

import path  # project-local path utilities (split_path_into_pieces, seek_file, ...)
# key-value map for building a data loader according to the type of label data
LABEL_DATA_TYPE = {
    'raw': 0,
    'reconstruction': 1,
    'optical_flow': 2,
    'mono_depth_estimation': 3.1,
    'stereo_depth_estimation': 3.2,
    'detection': 4,
    'tracking': 5,
    'recognition': 6
}
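
# For instance, a dataset whose config.yaml declares tracking labels maps to
# the code LABEL_DATA_TYPE['tracking'] == 5. A hypothetical lookup sketch
# (only the LABEL_DATA_TYPE dict above is assumed):
#
#     task = 'tracking'
#     assert task in LABEL_DATA_TYPE, 'Task name must be in %s' % LABEL_DATA_TYPE.keys()
#     label_code = LABEL_DATA_TYPE[task]  # -> 5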
# generate a parameter dictionary according to whether the dataset is labeled
def data_parameter_dict(data_filename, label_type):
    filename = path.split_path_into_pieces(data_filename)

    if os.path.isabs(data_filename):
        file_root = data_filename
        if os.path.isdir(file_root):
            search_root = file_root
        else:
            search_root = os.sep.join(filename[0:-1])
        config_filename = path.seek_file(search_root, 'config.yaml')
    else:
        file_root = os.path.join('', 'datasets', *filename)
        config_filename = os.path.join('', 'datasets', filename[0], 'config.yaml')

    try:
        with open(config_filename, 'r', encoding='utf-8') as fin:
            configs = yaml.load(fin, Loader=yaml.FullLoader)
    except TypeError as err:
        # seek_file returns None when no config.yaml is found; open(None) raises TypeError
        print('Cannot find config file: ' + str(err))
        raise err
    except KeyError as exception:
        print('ERROR! Task name does not exist')
        print('Task name must be in %s' % LABEL_DATA_TYPE.keys())
        raise exception
    is_labeled = configs.get('is_labeled')
    paraDict = {'spike_h': configs.get('spike_h'), 'spike_w': configs.get('spike_w')}
    paraDict['filelist'] = None

    if is_labeled:
        paraDict['labeled_data_type'] = configs.get('labeled_data_type')
        paraDict['labeled_data_suffix'] = configs.get('labeled_data_suffix')
        paraDict['label_root_list'] = None

        if os.path.isdir(file_root):
            # a directory of .dat files: pair every spike file with its label files
            filelist = sorted(glob.glob(file_root + '/*.dat'), key=os.path.getmtime)
            filepath = filelist[0]
            labelname = path.replace_identifier(filename, configs.get('data_field_identifier', ''),
                                                configs.get('label_field_identifier', ''))
            label_root_list = os.path.join('', 'datasets', *labelname)
            paraDict['labeled_data_dir'] = sorted(
                glob.glob(label_root_list + '/*.' + paraDict['labeled_data_suffix']),
                key=os.path.getmtime)
            paraDict['filelist'] = filelist
            paraDict['label_root_list'] = label_root_list
        else:
            # a single .dat file: derive the matching label file from its name
            filepath = glob.glob(file_root)[0]
            rawname = filename[-1].replace('.dat', '')
            filename.pop(-1)
            filename.append(rawname)
            labelname = path.replace_identifier(filename, configs.get('data_field_identifier', ''),
                                                configs.get('label_field_identifier', ''))
            label_root = os.path.join('', 'datasets', *labelname)
            paraDict['labeled_data_dir'] = glob.glob(label_root + '.' + paraDict['labeled_data_suffix'])[0]
    else:
        filepath = file_root

    paraDict['filepath'] = filepath
    return paraDict
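
# A minimal usage sketch for data_parameter_dict, assuming a dataset folder
# under ./datasets that contains a config.yaml (the dataset and file names
# below are hypothetical):
#
#     para_dict = data_parameter_dict('motVidarReal2020/spike59', 'tracking')
#     # para_dict now holds 'spike_h'/'spike_w' from config.yaml, the resolved
#     # 'filepath', and, for labeled datasets, 'labeled_data_dir' and friends.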
class SpikeStream:

    def __init__(self, **kwargs):
        self.SpikeMatrix = None
        self.filename = kwargs.get('filepath')
        if os.path.splitext(self.filename)[-1][1:] != 'dat':
            self.filename = self.filename + '.dat'
        self.spike_w = kwargs.get('spike_w')
        self.spike_h = kwargs.get('spike_h')
        self.print_dat_detail = kwargs.get('print_dat_detail', True)

    def get_spike_matrix(self, flipud=True, with_head=False):
        with open(self.filename, 'rb') as file_reader:
            video_seq = file_reader.read()
        video_seq = np.frombuffer(video_seq, 'b').astype(np.uint8)
        if self.print_dat_detail:
            print(video_seq)

        # a raw camera frame is 416 pixels wide; columns 400..415 carry head bits
        decode_width = 416 if with_head else self.spike_w

        img_size = self.spike_h * decode_width
        img_num = len(video_seq) // (img_size // 8)

        if self.print_dat_detail:
            print('loading total spikes from dat file -- spatial resolution: %d x %d, total timestamp: %d' %
                  (decode_width, self.spike_h, img_num))

        # the .dat stream is bit-packed: 8 pixels per byte, least significant bit first
        pix_id = np.arange(0, img_num * self.spike_h * decode_width)
        pix_id = np.reshape(pix_id, (img_num, self.spike_h, decode_width))
        comparator = np.left_shift(1, np.mod(pix_id, 8))
        byte_id = pix_id // 8
        data = video_seq[byte_id]
        result = np.bitwise_and(data, comparator)
        tmp_matrix = (result == comparator)

        # if the head columns were decoded, delete them
        if with_head:
            delete_indx = np.arange(400, 416)
            tmp_matrix = np.delete(tmp_matrix, delete_indx, 2)

        if flipud:
            self.SpikeMatrix = tmp_matrix[:, ::-1, :]
        else:
            self.SpikeMatrix = tmp_matrix

        self.SpikeMatrix = self.SpikeMatrix.astype(np.byte)
        return self.SpikeMatrix
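
    # Usage sketch for the head columns: a hypothetical stream recorded with
    # the 16 head columns is decoded at width 416 and stripped afterwards:
    #
    #     spikes = stream.get_spike_matrix(flipud=True, with_head=True)
    #     # -> shape (img_num, spike_h, 400); columns 400..415 were removed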

    # return spikes with the specified length, starting from a given frame index
    def get_block_spikes(self, begin_idx, block_len=1, flipud=True, with_head=False):
        with open(self.filename, 'rb') as file_reader:
            video_seq = file_reader.read()
        video_seq = np.frombuffer(video_seq, 'b').astype(np.uint8)

        # a raw camera frame is 416 pixels wide; columns 400..415 carry head bits
        decode_width = 416 if with_head else self.spike_w

        img_size = self.spike_h * decode_width
        img_num = len(video_seq) // (img_size // 8)

        end_idx = begin_idx + block_len
        if end_idx > img_num:
            warnings.warn("block_len exceeds the number of frames in the file! "
                          "Zeros will be padded at the end.", ResourceWarning)
            end_idx = img_num

        if self.print_dat_detail:
            print('loading a block of spikes from dat file -- spatial resolution: %d x %d, begin index: %d, block length: %d' %
                  (decode_width, self.spike_h, begin_idx, block_len))

        # bit-unpack exactly as in get_spike_matrix, but only for the requested block
        pix_id = np.arange(0, block_len * self.spike_h * decode_width)
        pix_id = np.reshape(pix_id, (block_len, self.spike_h, decode_width))
        comparator = np.left_shift(1, np.mod(pix_id, 8))
        byte_id = pix_id // 8

        id_start = begin_idx * img_size // 8
        id_end = id_start + block_len * img_size // 8
        data = video_seq[id_start:id_end]
        # zero-pad when the requested block runs past the end of the file,
        # so that indexing with byte_id below stays in bounds
        expected_bytes = block_len * img_size // 8
        if data.size < expected_bytes:
            data = np.concatenate([data, np.zeros(expected_bytes - data.size, dtype=data.dtype)])
        data_frame = data[byte_id]
        result = np.bitwise_and(data_frame, comparator)
        tmp_matrix = (result == comparator)

        # if the head columns were decoded, delete them
        if with_head:
            delete_indx = np.arange(400, 416)
            tmp_matrix = np.delete(tmp_matrix, delete_indx, 2)

        if flipud:
            self.SpikeMatrix = tmp_matrix[:, ::-1, :]
        else:
            self.SpikeMatrix = tmp_matrix

        self.SpikeMatrix = self.SpikeMatrix.astype(np.byte)
        return self.SpikeMatrix
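

if __name__ == '__main__':
    # Minimal end-to-end sketch, assuming a dataset laid out under ./datasets
    # with a config.yaml; the dataset and file names below are hypothetical.
    para_dict = data_parameter_dict('recVidarReal2019/spike59', 'raw')
    vidarSpikes = SpikeStream(**para_dict)
    spikes = vidarSpikes.get_spike_matrix()
    print(spikes.shape)  # -> (total timestamps, spike_h, spike_w)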