import sys import numpy as np import torch import threading import cv2 import json # import matplotlib # matplotlib.use('TkAgg') import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D from matplotlib.pyplot import MultipleLocator class dataReader(threading.Thread): def __init__(self, file_reader, device, q, is_dat=True, is_npy=False, filedir=None): super(dataReader, self).__init__() self.file_reader = file_reader self.device = device self.q = q self.is_dat = is_dat self.is_npy = is_npy self.filedir = filedir self.stream = torch.cuda.Stream() def run(self): with torch.cuda.stream(self.stream): for t in range(tnum): if self.is_dat: ibuffer = self.file_reader.read(int(ivs_w * ivs_h / 8)) a = bin(int.from_bytes(ibuffer, byteorder=sys.byteorder)) a = a[2:].zfill(ivs_w * ivs_h) a = list(a) a = np.array(a, dtype=np.byte) a = np.reshape(a, [ivs_h, ivs_w]) if ivs_h == 600: a = np.flip(a, 0) if ivs_h == 250: a = np.flip(a, 1) input_spk = torch.from_numpy(a != 0).to(device) elif self.is_npy: npy_filename = self.filedir + str(t + 442) + '.npy' tmp_data = np.load(npy_filename) superResolution_rate = tmp_data.shape[2] for i_data in range(superResolution_rate): tmp_spk = tmp_data[:, :, i_data] input_spk = torch.from_numpy(tmp_spk).to(device) self.q.put(input_spk) else: # img_filename = self.filedir + str(t + 4200) + '.png' img_filename = self.filedir + 'spike_' + str(t + 1) + '.png' # print('reading %d frames' % (t+1)) # print('reading %d frames' % (t+5000)) a = cv2.imread(img_filename) a = cv2.cvtColor(a, cv2.COLOR_BGR2GRAY) a = a / 255 a = np.array(a, dtype=np.byte) input_spk = torch.from_numpy(a != 0).to(device) self.q.put(input_spk) # obtain 2D gaussian filter def get_kernel(filter_size, sigma): assert (filter_size + 1) % 2 == 0, '2D filter size must be odd number!' g = np.zeros((filter_size, filter_size), dtype=np.float32) half_width = int((filter_size - 1) / 2) # center location xc = (filter_size + 1) / 2 yc = (filter_size + 1) / 2 for i in range(-half_width, half_width + 1, 1): for j in range(-half_width, half_width + 1, 1): x = int(xc + i) y = int(yc + j) g[y - 1, x - 1] = np.exp(- (i ** 2 + j ** 2) / 2 / sigma / sigma) g = (g - g.min()) / (g.max() - g.min()) return g def get_transform_matrix(ori, speed): ori_num = len(ori) speed_num = len(speed) transform_matrix = torch.zeros(ori_num * speed_num, 2, 3) cnt = 0 for iOri in range(ori_num): for iSpeed in range(speed_num): transform_matrix[cnt, 0, 0] = 1 transform_matrix[cnt, 1, 1] = 1 transform_matrix[cnt, 0, 2] = - float(ori[iOri, 1] * speed[iSpeed] / ivs_w) transform_matrix[cnt, 1, 2] = - float(ori[iOri, 0] * speed[iSpeed] / ivs_h) cnt += 1 transform_matrix = transform_matrix.to(device) return transform_matrix def get_transform_matrix_new(ori, speed, dvs_w, dvs_h, device): ori_num = len(ori) speed_num = len(speed) transform_matrix = torch.zeros(ori_num * speed_num, 2, 3) cnt = 0 for iOri in range(ori_num): for iSpeed in range(speed_num): transform_matrix[cnt, 0, 0] = 1 transform_matrix[cnt, 1, 1] = 1 transform_matrix[cnt, 0, 2] = - float(ori[iOri, 1] * speed[iSpeed] / dvs_w) transform_matrix[cnt, 1, 2] = - float(ori[iOri, 0] * speed[iSpeed] / dvs_h) cnt += 1 transform_matrix = transform_matrix.to(device) return transform_matrix # monitor the inference process def visualize_img(gray_img, tag, curT): gray_img = gray_img.float32() img = torch.unsqueeze(gray_img, 0) logger.add_image(tag, img, global_step=curT) def visualize_images(images, tag, curT): if images.shape[0] < 1: return images = torch.squeeze(images) img_num = images.shape[-1] for iImg in range(img_num): tmp_img = images[:, :, iImg] tmp_img = torch.squeeze(tmp_img) tmp_img = torch.unsqueeze(tmp_img, 0) logger.add_image(tag + str(iImg), tmp_img, global_step=curT) def visualize_weights(weights, tag, curT): if weights.shape[0] < 1: return weights = torch.squeeze(weights) weights_num = weights.shape[0] input_size = weights.shape[1] stim_size = int(np.sqrt(input_size)) for iw in range(weights_num): tmp_w = weights[iw, :] tmp_w = torch.squeeze(tmp_w) tmp_w = (tmp_w - torch.min(tmp_w)) / (torch.max(tmp_w) - torch.min(tmp_w)) tmp_w = torch.reshape(tmp_w, (stim_size, stim_size)) tmp_w = torch.unsqueeze(tmp_w, 0) logger.add_image(tag + str(iw), tmp_w, global_step=curT) class NumpyEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, np.ndarray): return obj.tolist() return json.JSONEncoder.default(self, obj) def vis_trajectory(json_file, filename, **dataDict): spike_h = dataDict.get('spike_h') spike_w = dataDict.get('spike_w') traj_dict = [] with open(json_file, 'r') as f: for line in f.readlines(): traj_dict.append(json.loads(line)) num_traj = len(traj_dict) fig = plt.figure(figsize=[10, 6]) ax = fig.add_subplot(111, projection='3d') min_t = 1000 max_t = 0 for tmp_traj in traj_dict: tmp_t = np.array(tmp_traj['t']) if np.min(tmp_t) < min_t: min_t = np.min(tmp_t) if np.max(tmp_t) > max_t: max_t = np.max(tmp_t) tmp_x = spike_w - np.array(tmp_traj['x']) tmp_y = np.array(tmp_traj['y']) tmp_color = np.array(tmp_traj['color']) / 255. ax.plot(tmp_t, tmp_x, tmp_y, color=tmp_color, linewidth=2, label='traj ' + str(tmp_traj['id'])) ax.legend(loc='best', bbox_to_anchor=(0.7, 0., 0.4, 0.8)) zoom = [2.2, 0.8, 0.5, 1] ax.get_proj = lambda: np.dot(Axes3D.get_proj(ax), np.diag([zoom[0], zoom[1], zoom[2], zoom[3]])) ax.set_xlim(min_t, max_t) ax.set_ylim(0, spike_w) ax.set_zlim(0, spike_h) ax.set_xlabel('time', fontsize=15) ax.set_ylabel('width', fontsize=15) ax.set_zlabel('height', fontsize=15) ax.view_init(elev=16, azim=135) ax.yaxis.set_major_locator(MultipleLocator(100)) fig.subplots_adjust(top=1., bottom=0., left=0.2, right=1.) # fig.tight_layout() plt.show() plt.savefig(filename, dpi=500, transparent=True)