zzzzzeee's picture
Upload 28 files
9fa5305 verified
# -*- coding: utf-8 -*-
# @Time : 2023/7/16 20:13
# @Author : Yajing Zheng
# @Email: [email protected]
# @File : load_dat.py
import os, sys
import warnings
import glob
import yaml
import numpy as np
import path
# Maps each label/task name to the numeric code used to pick a data loader.
# NOTE: 'stero_depth_estimation' is spelled this way on purpose here -- it is a
# runtime key, so it must stay exactly as callers and config files spell it.
LABEL_DATA_TYPE = dict(
    raw=0,
    reconstruction=1,
    optical_flow=2,
    mono_depth_estimation=3.1,
    stero_depth_estimation=3.2,
    detection=4,
    tracking=5,
    recognition=6,
)
# Build the parameter dictionary for a spike data path, adding label info when the dataset is labeled
def data_parameter_dict(data_filename, label_type):
    """Assemble the loader parameters for a ``.dat`` file or a directory of them.

    The dataset's ``config.yaml`` is read to obtain the spike resolution and,
    when ``is_labeled`` is set there, the label type/suffix and label paths.

    Args:
        data_filename: Path to a ``.dat`` file or to a directory holding
            ``.dat`` files. An absolute path is used as-is; a relative path is
            resolved under the ``datasets`` directory.
        label_type: Not used by this function; kept so existing callers keep
            working.

    Returns:
        dict with the keys ``spike_h``, ``spike_w``, ``filelist`` and
        ``filepath``. For labeled datasets it also carries
        ``labeled_data_type``, ``labeled_data_suffix``, ``label_root_list``
        and ``labeled_data_dir``.

    Raises:
        TypeError: Re-raised from opening the config file (presumably when
            ``path.seek_file`` found no ``config.yaml`` and returned ``None``
            -- verify against that helper).
        KeyError: Re-raised if reading the config raises it.
        IndexError: If no data file or label file matches the expected path.
    """
    filename = path.split_path_into_pieces(data_filename)

    if os.path.isabs(data_filename):
        file_root = data_filename
        if os.path.isdir(file_root):
            search_root = file_root
        else:
            # Parent directory of the file. os.path.dirname is used instead of
            # joining the pieces with a hard-coded '\\', which only produced a
            # usable path on Windows.
            search_root = os.path.dirname(file_root)
        config_filename = path.seek_file(search_root, 'config.yaml')
    else:
        file_root = os.path.join('', 'datasets', *filename)
        config_filename = os.path.join('', 'datasets', filename[0], 'config.yaml')

    try:
        with open(config_filename, 'r', encoding='utf-8') as fin:
            # NOTE(review): FullLoader can build more than plain data types;
            # consider yaml.safe_load if config files may be untrusted.
            configs = yaml.load(fin, Loader=yaml.FullLoader)
    except TypeError as err:
        print("Cannot find config file" + str(err))
        raise
    except KeyError:
        print('ERROR! Task name does not exist')
        print('Task name must be in %s' % LABEL_DATA_TYPE.keys())
        raise

    is_labeled = configs.get('is_labeled')
    paraDict = {
        'spike_h': configs.get('spike_h'),
        'spike_w': configs.get('spike_w'),
        'filelist': None,
    }

    if is_labeled:
        paraDict['labeled_data_type'] = configs.get('labeled_data_type')
        paraDict['labeled_data_suffix'] = configs.get('labeled_data_suffix')
        paraDict['label_root_list'] = None
        data_identifier = configs.get('data_field_identifier', '')
        label_identifier = configs.get('label_field_identifier', '')

        if os.path.isdir(file_root):
            # Directory input: every .dat file inside it, oldest first.
            filelist = sorted(glob.glob(file_root + '/*.dat'), key=os.path.getmtime)
            if not filelist:
                # Same exception type as the bare [0] lookup, but with context.
                raise IndexError('no .dat files found in %s' % file_root)
            filepath = filelist[0]

            labelname = path.replace_identifier(filename, data_identifier, label_identifier)
            label_root_list = os.path.join('', 'datasets', *labelname)
            paraDict['labeled_data_dir'] = sorted(
                glob.glob(label_root_list + '/*.' + paraDict['labeled_data_suffix']),
                key=os.path.getmtime)
            paraDict['filelist'] = filelist
            paraDict['label_root_list'] = label_root_list
        else:
            data_matches = glob.glob(file_root)
            if not data_matches:
                raise IndexError('data file not found: %s' % file_root)
            filepath = data_matches[0]

            # Drop only a trailing '.dat'; str.replace would also remove
            # '.dat' occurring in the middle of the name.
            rawname = filename[-1]
            if rawname.endswith('.dat'):
                rawname = rawname[:-len('.dat')]
            label_pieces = list(filename[:-1]) + [rawname]

            labelname = path.replace_identifier(label_pieces, data_identifier, label_identifier)
            label_root = os.path.join('', 'datasets', *labelname)
            label_pattern = label_root + '.' + paraDict['labeled_data_suffix']
            label_matches = glob.glob(label_pattern)
            if not label_matches:
                raise IndexError('label file not found: %s' % label_pattern)
            paraDict['labeled_data_dir'] = label_matches[0]
    else:
        filepath = file_root

    paraDict['filepath'] = filepath
    return paraDict
class SpikeStream:
    """Reader that decodes bit-packed spike frames from a ``.dat`` file.

    Every pixel occupies one bit. Pixels are numbered frame by frame, row by
    row; pixel ``p`` is stored in byte ``p // 8`` at bit ``p % 8`` (least
    significant bit first).

    Keyword Args:
        filepath: Path of the data file; ``.dat`` is appended when the path
            does not already end with that extension.
        spike_w: Frame width in pixels.
        spike_h: Frame height in pixels.
        print_dat_detail: Whether to print loading details (default ``True``).
    """

    # With ``with_head=True`` each row is decoded at this width ...
    _HEAD_DECODE_WIDTH = 416
    # ... and only the first this-many columns are kept (columns 400-415 are dropped).
    _HEAD_VALID_WIDTH = 400

    def __init__(self, **kwargs):
        self.SpikeMatrix = None
        self.filename = kwargs.get('filepath')
        if os.path.splitext(self.filename)[-1][1:] != 'dat':
            self.filename = self.filename + '.dat'
        self.spike_w = kwargs.get('spike_w')
        self.spike_h = kwargs.get('spike_h')
        self.print_dat_detail = kwargs.get('print_dat_detail', True)

    def _decode_width(self, with_head):
        """Return the row width used for decoding."""
        return self._HEAD_DECODE_WIDTH if with_head else self.spike_w

    def _unpack_frames(self, packed, frame_num, decode_width, flipud, with_head):
        """Decode ``frame_num`` frames from the uint8 array ``packed``.

        Returns an ``np.byte`` array of zeros and ones with shape
        ``(frame_num, spike_h, width)``, where ``width`` is ``decode_width``
        or, with ``with_head``, the width left after dropping the head columns.
        """
        bit_total = frame_num * self.spike_h * decode_width
        byte_total = (bit_total + 7) // 8
        # bitorder='little' puts bit 0 of each byte first, matching the
        # "pixel p lives at bit p % 8 of byte p // 8" layout.
        bits = np.unpackbits(packed[:byte_total], bitorder='little')[:bit_total]
        frames = bits.reshape(frame_num, self.spike_h, decode_width)
        if with_head:
            frames = frames[:, :, :self._HEAD_VALID_WIDTH]
        if flipud:
            frames = frames[:, ::-1, :]
        return frames.astype(np.byte)

    def get_spike_matrix(self, flipud=True, with_head=False):
        """Decode every complete frame in the file.

        Args:
            flipud: Reverse the row order of each frame when true.
            with_head: Decode rows at width 416 and drop columns 400-415.

        Returns:
            ``np.byte`` array of shape ``(frames, spike_h, width)``; it is
            also stored in ``self.SpikeMatrix``.
        """
        # The context manager closes the file even if decoding fails later.
        with open(self.filename, 'rb') as file_reader:
            raw = file_reader.read()
        video_seq = np.frombuffer(raw, 'b')
        if self.print_dat_detail:
            print(video_seq)

        decode_width = self._decode_width(with_head)
        img_size = self.spike_h * decode_width
        img_num = len(video_seq) // (img_size // 8)
        if self.print_dat_detail:
            print('loading total spikes from dat file -- spatial resolution: %d x %d, total timestamp: %d' %
                  (decode_width, self.spike_h, img_num))

        # Reinterpret the signed bytes as unsigned so they can be unpacked.
        self.SpikeMatrix = self._unpack_frames(
            video_seq.view(np.uint8), img_num, decode_width, flipud, with_head)
        return self.SpikeMatrix

    # return spikes with specified length and begin index
    def get_block_spikes(self, begin_idx, block_len=1, flipud=True, with_head=False):
        """Decode ``block_len`` consecutive frames starting at ``begin_idx``.

        When the requested block runs past the last complete frame, a
        ``ResourceWarning`` is issued and the missing frames are returned as
        all-zero frames, so the result always has ``block_len`` frames.

        Args:
            begin_idx: Index of the first frame to decode.
            block_len: Number of frames to return.
            flipud: Reverse the row order of each frame when true.
            with_head: Decode rows at width 416 and drop columns 400-415.

        Returns:
            ``np.byte`` array of shape ``(block_len, spike_h, width)``; it is
            also stored in ``self.SpikeMatrix``.
        """
        with open(self.filename, 'rb') as file_reader:
            raw = file_reader.read()
        video_seq = np.frombuffer(raw, np.uint8)

        decode_width = self._decode_width(with_head)
        img_size = self.spike_h * decode_width
        frame_bytes = img_size // 8
        img_num = len(video_seq) // frame_bytes

        end_idx = begin_idx + block_len
        if end_idx > img_num:
            warnings.warn("block_len exceeding upper limit! Zeros will be padded in the end. ", ResourceWarning)
            end_idx = img_num

        if self.print_dat_detail:
            print(
                'loading total spikes from dat file -- spatial resolution: %d x %d, begin index: %d total timestamp: %d' %
                (decode_width, self.spike_h, begin_idx, block_len))

        # Frames that really exist in the file for this request.
        valid_len = max(0, end_idx - begin_idx)
        id_start = begin_idx * img_size // 8
        data = video_seq[id_start:id_start + valid_len * frame_bytes]
        # Never decode more frames than the slice actually holds.
        valid_len = min(valid_len, data.size // frame_bytes)

        block = self._unpack_frames(data, valid_len, decode_width, flipud, with_head)
        if valid_len < block_len:
            # Deliver the zero padding announced by the warning above.
            padded = np.zeros((block_len,) + block.shape[1:], dtype=np.byte)
            padded[:valid_len] = block
            block = padded

        self.SpikeMatrix = block
        return self.SpikeMatrix