# -*- coding: utf-8 -*-
# @Time    : 2023/7/16 20:13
# @Author  : Yajing Zheng
# @Email   : [email protected]
# @File    : load_dat.py
import os
import warnings
import glob
import yaml
import numpy as np
import path
# key-value pairs for generating the data loader according to the type of label data
LABEL_DATA_TYPE = {
    'raw': 0,
    'reconstruction': 1,
    'optical_flow': 2,
    'mono_depth_estimation': 3.1,
    'stero_depth_estimation': 3.2,
    'detection': 4,
    'tracking': 5,
    'recognition': 6
}
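
# Example: LABEL_DATA_TYPE['optical_flow'] evaluates to 2; looking up an unknown
# task name raises KeyError, which data_parameter_dict below catches and reports
# together with the list of valid task names.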

# generate a parameter dictionary according to whether the data is labeled
def data_parameter_dict(data_filename, label_type):
    filename = path.split_path_into_pieces(data_filename)

    if os.path.isabs(data_filename):
        file_root = data_filename
        if os.path.isdir(file_root):
            search_root = file_root
        else:
            # use the platform's path separator instead of a hard-coded backslash
            search_root = os.sep.join(filename[0:-1])
        config_filename = path.seek_file(search_root, 'config.yaml')
    else:
        file_root = os.path.join('', 'datasets', *filename)
        config_filename = os.path.join('', 'datasets', filename[0], 'config.yaml')
    try:
        with open(config_filename, 'r', encoding='utf-8') as fin:
            configs = yaml.load(fin, Loader=yaml.FullLoader)
        # validate the requested task name so the KeyError branch below is reachable
        _ = LABEL_DATA_TYPE[label_type]
    except FileNotFoundError as err:
        # a missing config file raises FileNotFoundError, not TypeError
        print('Cannot find config file: ' + str(err))
        raise err
    except KeyError as exception:
        print('ERROR! Task name does not exist')
        print('Task name must be in %s' % list(LABEL_DATA_TYPE.keys()))
        raise exception
    is_labeled = configs.get('is_labeled')
    paraDict = {'spike_h': configs.get('spike_h'), 'spike_w': configs.get('spike_w')}
    paraDict['filelist'] = None
    if is_labeled:
        paraDict['labeled_data_type'] = configs.get('labeled_data_type')
        paraDict['labeled_data_suffix'] = configs.get('labeled_data_suffix')
        paraDict['label_root_list'] = None
        if os.path.isdir(file_root):
            filelist = sorted(glob.glob(file_root + '/*.dat'), key=os.path.getmtime)
            filepath = filelist[0]
            labelname = path.replace_identifier(filename, configs.get('data_field_identifier', ''),
                                                configs.get('label_field_identifier', ''))
            label_root_list = os.path.join('', 'datasets', *labelname)
            paraDict['labeled_data_dir'] = sorted(glob.glob(label_root_list + '/*.' + paraDict['labeled_data_suffix']),
                                                  key=os.path.getmtime)
            paraDict['filelist'] = filelist
            paraDict['label_root_list'] = label_root_list
        else:
            filepath = glob.glob(file_root)[0]
            rawname = filename[-1].replace('.dat', '')
            filename.pop(-1)
            filename.append(rawname)
            labelname = path.replace_identifier(filename, configs.get('data_field_identifier', ''),
                                                configs.get('label_field_identifier', ''))
            label_root = os.path.join('', 'datasets', *labelname)
            paraDict['labeled_data_dir'] = glob.glob(label_root + '.' + paraDict['labeled_data_suffix'])[0]
    else:
        filepath = file_root

    paraDict['filepath'] = filepath
    return paraDict
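
# A minimal usage sketch (illustration only; the dataset name 'recVidarReal2019'
# and its config.yaml are assumptions and must actually exist under ./datasets):
#   params = data_parameter_dict('recVidarReal2019/spike59', 'raw')
#   # params now holds 'spike_h', 'spike_w', 'filepath', and, for labeled
#   # datasets, the matching label paths.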

class SpikeStream:
    def __init__(self, **kwargs):
        self.SpikeMatrix = None
        self.filename = kwargs.get('filepath')
        # append the .dat suffix if the caller passed a bare file stem
        if os.path.splitext(self.filename)[-1][1:] != 'dat':
            self.filename = self.filename + '.dat'
        self.spike_w = kwargs.get('spike_w')
        self.spike_h = kwargs.get('spike_h')
        self.print_dat_detail = kwargs.get('print_dat_detail', True)

    def get_spike_matrix(self, flipud=True, with_head=False):
        file_reader = open(self.filename, 'rb')
        video_seq = file_reader.read()
        video_seq = np.frombuffer(video_seq, 'b')
        video_seq = np.array(video_seq).astype(np.byte)
        if self.print_dat_detail:
            print(video_seq)

        # raw frames are decoded at a width of 416 pixels; the extra columns
        # beyond the sensor width apparently carry header data and are removed below
        if with_head:
            decode_width = 416
        else:
            decode_width = self.spike_w

        # each pixel is stored as a single bit, so one frame occupies img_size / 8 bytes
        img_size = self.spike_h * decode_width
        img_num = len(video_seq) // (img_size // 8)

        if self.print_dat_detail:
            print('loading total spikes from dat file -- spatial resolution: %d x %d, total timestamp: %d' %
                  (decode_width, self.spike_h, img_num))

        # unpack the bit stream: pixel i lives in bit (i mod 8) of byte (i // 8)
        pix_id = np.arange(0, img_num * self.spike_h * decode_width)
        pix_id = np.reshape(pix_id, (img_num, self.spike_h, decode_width))
        comparator = np.left_shift(1, np.mod(pix_id, 8))
        byte_id = pix_id // 8
        data = video_seq[byte_id]
        result = np.bitwise_and(data, comparator)
        tmp_matrix = (result == comparator)

        # if the frames carry a head, delete those columns
        if with_head:
            delete_indx = np.arange(400, 416)
            tmp_matrix = np.delete(tmp_matrix, delete_indx, 2)
        if flipud:
            self.SpikeMatrix = tmp_matrix[:, ::-1, :]
        else:
            self.SpikeMatrix = tmp_matrix

        file_reader.close()
        self.SpikeMatrix = self.SpikeMatrix.astype(np.byte)
        return self.SpikeMatrix
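
    # Usage sketch (illustrative; the path and the 250 x 400 resolution are
    # assumptions matching a typical Vidar spike camera, and may differ for your data):
    #   ss = SpikeStream(filepath='datasets/demo/0.dat', spike_h=250, spike_w=400,
    #                    print_dat_detail=False)
    #   spikes = ss.get_spike_matrix()  # int8 array of shape (T, 250, 400) with 0/1 entries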

    # return a block of spikes of the specified length, starting from begin_idx
    def get_block_spikes(self, begin_idx, block_len=1, flipud=True, with_head=False):
        file_reader = open(self.filename, 'rb')
        video_seq = file_reader.read()
        video_seq = np.frombuffer(video_seq, 'b')
        video_seq = np.array(video_seq).astype(np.uint8)

        if with_head:
            decode_width = 416
        else:
            decode_width = self.spike_w

        img_size = self.spike_h * decode_width
        img_num = len(video_seq) // (img_size // 8)

        end_idx = begin_idx + block_len
        if end_idx > img_num:
            warnings.warn("block_len exceeds the number of available frames; zeros will be padded at the end.",
                          ResourceWarning)
            end_idx = img_num
        if self.print_dat_detail:
            print(
                'loading total spikes from dat file -- spatial resolution: %d x %d, begin index: %d total timestamp: %d' %
                (decode_width, self.spike_h, begin_idx, block_len))

        pix_id = np.arange(0, block_len * self.spike_h * decode_width)
        pix_id = np.reshape(pix_id, (block_len, self.spike_h, decode_width))
        comparator = np.left_shift(1, np.mod(pix_id, 8))
        byte_id = pix_id // 8

        id_start = begin_idx * img_size // 8
        id_end = id_start + block_len * img_size // 8
        data = video_seq[id_start:id_end]
        # pad with zeros when the requested block runs past the end of the file,
        # as promised by the warning above; without this the fancy indexing below
        # would raise an IndexError
        if len(data) < block_len * img_size // 8:
            data = np.concatenate((data, np.zeros(block_len * img_size // 8 - len(data), dtype=np.uint8)))
        data_frame = data[byte_id]
        result = np.bitwise_and(data_frame, comparator)
        tmp_matrix = (result == comparator)
        # if with head, delete them
        if with_head:
            delete_indx = np.arange(400, 416)
            tmp_matrix = np.delete(tmp_matrix, delete_indx, 2)
        if flipud:
            self.SpikeMatrix = tmp_matrix[:, ::-1, :]
        else:
            self.SpikeMatrix = tmp_matrix

        file_reader.close()
        self.SpikeMatrix = self.SpikeMatrix.astype(np.byte)
        return self.SpikeMatrix
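
# A hedged end-to-end sketch (commented out so importing this module has no side
# effects; the dataset path 'recVidarReal2019/spike59' is an assumption for
# illustration and must exist under ./datasets with a config.yaml):
#   if __name__ == '__main__':
#       params = data_parameter_dict('recVidarReal2019/spike59', 'raw')
#       vidarSpikes = SpikeStream(**params, print_dat_detail=False)
#       block = vidarSpikes.get_block_spikes(begin_idx=0, block_len=100)
#       print(block.shape)  # expected: (100, spike_h, spike_w)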