zzzzzeee's picture
Upload 28 files
9fa5305 verified
import sys
import numpy as np
import torch
import threading
import cv2
import json
# import matplotlib
# matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.pyplot import MultipleLocator
class dataReader(threading.Thread):
    """Producer thread that reads spike frames and puts them on a queue.

    The frame source is selected by the constructor flags:

    * ``is_dat=True``: bit-packed frames read from ``file_reader``,
      ``int(ivs_w * ivs_h / 8)`` bytes per frame.
    * ``is_npy=True`` (with ``is_dat=False``): files
      ``filedir + str(t + 442) + '.npy'``; every slice along axis 2 of the
      loaded array is queued as its own frame.
    * otherwise: PNG files ``filedir + 'spike_' + str(t + 1) + '.png'``,
      converted to grayscale.

    Every queued item is a ``torch`` tensor moved to ``self.device``.

    NOTE(review): ``run`` reads the module-level globals ``tnum``, ``ivs_w`` and
    ``ivs_h``, which are not defined in this file; they must be assigned on this
    module before the thread is started.
    """

    def __init__(self, file_reader, device, q, is_dat=True, is_npy=False, filedir=None):
        super(dataReader, self).__init__()
        self.file_reader = file_reader  # file-like; .read(n) must return bytes (only used when is_dat)
        self.device = device            # target device for every queued tensor
        self.q = q                      # output queue; only .put() is used
        self.is_dat = is_dat
        self.is_npy = is_npy
        self.filedir = filedir          # path prefix for the .npy / .png inputs
        # Side CUDA stream used around the device transfers in run(); left as
        # None when CUDA is unavailable so the reader also works on CPU hosts.
        self.stream = torch.cuda.Stream() if torch.cuda.is_available() else None

    def run(self):
        """Read ``tnum`` time steps and put each resulting frame on ``self.q`` exactly once."""
        import contextlib

        if self.stream is None:
            stream_ctx = contextlib.nullcontext()
        else:
            stream_ctx = torch.cuda.stream(self.stream)
        with stream_ctx:
            for t in range(tnum):
                if self.is_dat:
                    self.q.put(self._read_dat_frame())
                elif self.is_npy:
                    for frame in self._read_npy_frames(t):
                        self.q.put(frame)
                else:
                    self.q.put(self._read_png_frame(t))

    def _read_dat_frame(self):
        """Read one bit-packed frame from ``self.file_reader`` and return it as a bool tensor."""
        n_pixels = ivs_w * ivs_h
        ibuffer = self.file_reader.read(int(n_pixels / 8))
        # Interpret the whole buffer as one integer (native byte order) and
        # expand it into a zero-padded bit string, one character per pixel.
        bits = bin(int.from_bytes(ibuffer, byteorder=sys.byteorder))[2:].zfill(n_pixels)
        a = np.reshape(np.array(list(bits), dtype=np.byte), [ivs_h, ivs_w])
        # Orientation fix-ups keyed on the frame height (presumably specific to
        # the sensors producing 600- and 250-row frames).
        if ivs_h == 600:
            a = np.flip(a, 0)
        if ivs_h == 250:
            a = np.flip(a, 1)
        return torch.from_numpy(a != 0).to(self.device)

    def _read_npy_frames(self, t):
        """Yield one tensor per slice along axis 2 of ``filedir + str(t + 442) + '.npy'``."""
        # NOTE(review): the +442 file offset is hard-coded; presumably the first
        # file index of a particular dataset.
        tmp_data = np.load(self.filedir + str(t + 442) + '.npy')
        for i_data in range(tmp_data.shape[2]):
            yield torch.from_numpy(tmp_data[:, :, i_data]).to(self.device)

    def _read_png_frame(self, t):
        """Load ``filedir + 'spike_' + str(t + 1) + '.png'`` and return a bool spike tensor."""
        img_filename = self.filedir + 'spike_' + str(t + 1) + '.png'
        a = cv2.imread(img_filename)
        if a is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError('cannot read spike image: ' + img_filename)
        a = cv2.cvtColor(a, cv2.COLOR_BGR2GRAY)
        # Scaling to [0, 1] and truncating to int8 keeps only pixels equal to 255
        # as spikes. NOTE(review): confirm that this strict threshold is intended.
        a = np.array(a / 255, dtype=np.byte)
        return torch.from_numpy(a != 0).to(self.device)
# obtain 2D gaussian filter
def get_kernel(filter_size, sigma):
    """Return a ``filter_size x filter_size`` 2-D Gaussian kernel rescaled to [0, 1].

    The value at offset ``(j, i)`` from the centre cell is
    ``exp(-(i**2 + j**2) / (2 * sigma**2))``. The kernel is then min-max scaled,
    so the centre is 1 and the smallest (corner) entries are 0. A 1x1 kernel has
    no spread to scale and is returned as ``[[1.0]]``.

    Args:
        filter_size: odd side length of the kernel.
        sigma: Gaussian standard deviation in cells.

    Returns:
        ``np.float32`` array of shape ``(filter_size, filter_size)``.

    Raises:
        AssertionError: if ``filter_size`` is even.
    """
    assert (filter_size + 1) % 2 == 0, '2D filter size must be odd number!'
    half_width = (filter_size - 1) // 2
    offsets = np.arange(-half_width, half_width + 1)
    # Rows index the vertical offset, columns the horizontal one; the kernel is symmetric.
    sq_dist = offsets[:, None] ** 2 + offsets[None, :] ** 2
    g = np.exp(-sq_dist / 2 / sigma / sigma).astype(np.float32)
    g_range = g.max() - g.min()
    if g_range == 0:
        # e.g. 1x1 kernel: min == max, so min-max scaling would be 0/0 (NaN).
        return np.ones_like(g)
    return (g - g.min()) / g_range
def get_transform_matrix(ori, speed):
    """Build one 2x3 translation matrix per (orientation, speed) pair from module globals.

    Legacy wrapper around ``get_transform_matrix_new``. It takes the frame width
    and height from the module-level globals ``ivs_w`` / ``ivs_h`` and the target
    device from the global ``device``. Results are identical to calling
    ``get_transform_matrix_new(ori, speed, ivs_w, ivs_h, device)``.

    NOTE(review): ``ivs_w``, ``ivs_h`` and ``device`` are not defined in this
    file; prefer calling ``get_transform_matrix_new`` with explicit arguments.
    """
    return get_transform_matrix_new(ori, speed, ivs_w, ivs_h, device)
def get_transform_matrix_new(ori, speed, dvs_w, dvs_h, device):
    """Build one 2x3 translation-only affine matrix per (orientation, speed) pair.

    Rows are ordered orientation-major: all speeds for ``ori[0]``, then all
    speeds for ``ori[1]``, and so on. For orientation row ``o`` and speed ``s``
    the matrix is ``[[1, 0, -o[1]*s/dvs_w], [0, 1, -o[0]*s/dvs_h]]``.

    Args:
        ori: indexable as ``ori[k, 0]`` / ``ori[k, 1]`` (e.g. an ``(N, 2)`` array);
            column 1 is divided by the width, column 0 by the height.
        speed: sequence of speed magnitudes.
        dvs_w: frame width used to normalise the horizontal shift.
        dvs_h: frame height used to normalise the vertical shift.
        device: device the result is moved to.

    Returns:
        Tensor of shape ``(len(ori) * len(speed), 2, 3)`` on ``device``.
    """
    pairs = [(o_idx, s_idx) for o_idx in range(len(ori)) for s_idx in range(len(speed))]
    theta = torch.zeros(len(pairs), 2, 3)
    # Identity linear part for every row; only the translation column varies.
    theta[:, 0, 0] = 1
    theta[:, 1, 1] = 1
    for row, (o_idx, s_idx) in enumerate(pairs):
        step = speed[s_idx]
        theta[row, 0, 2] = -float(ori[o_idx, 1] * step / dvs_w)
        theta[row, 1, 2] = -float(ori[o_idx, 0] * step / dvs_h)
    return theta.to(device)
# monitor the inference process
def visualize_img(gray_img, tag, curT, writer=None):
    """Log one single-channel image via ``writer.add_image``.

    Args:
        gray_img: image tensor, assumed ``(H, W)`` (TODO confirm with callers);
            converted to float32 before logging.
        tag: image tag.
        curT: global step passed as ``global_step``.
        writer: object providing ``add_image(tag, img, global_step=...)``.
            Defaults to the module-level ``logger``, which is not defined in this
            file and must be set elsewhere.
    """
    # torch.Tensor has no ``float32()`` method; ``.float()`` is the float32 cast.
    gray_img = gray_img.float()
    img = torch.unsqueeze(gray_img, 0)  # prepend a channel axis: (H, W) -> (1, H, W)
    if writer is None:
        writer = logger
    writer.add_image(tag, img, global_step=curT)
def visualize_images(images, tag, curT, writer=None):
    """Log every image of a stack, tagged ``tag + str(index)``.

    ``images`` is squeezed and then sliced as ``images[:, :, k]``, so after
    removing singleton dimensions it is expected to be ``(H, W, N)``
    (NOTE(review): confirm against callers). A single image whose last axis is
    squeezed away is treated as ``N == 1``. Nothing is logged when
    ``images.shape[0] < 1``.

    Args:
        images: image stack tensor.
        tag: tag prefix; the image index is appended.
        curT: global step passed as ``global_step``.
        writer: object providing ``add_image(tag, img, global_step=...)``.
            Defaults to the module-level ``logger``, which is not defined in this
            file and must be set elsewhere.
    """
    if images.shape[0] < 1:
        return
    if writer is None:
        writer = logger
    images = torch.squeeze(images)
    if images.dim() == 2:
        # A lone (H, W, 1) image loses its stack axis in squeeze(); restore it so
        # the loop sees one image instead of indexing past the tensor's rank.
        images = torch.unsqueeze(images, -1)
    for idx in range(images.shape[-1]):
        frame = torch.squeeze(images[:, :, idx])
        writer.add_image(tag + str(idx), torch.unsqueeze(frame, 0), global_step=curT)
def visualize_weights(weights, tag, curT, writer=None):
    """Log each weight row as a square image tagged ``tag + str(index)``.

    After squeezing, each row of ``weights`` is min-max scaled to [0, 1] and
    reshaped to ``(s, s)`` with ``s = int(sqrt(weights.shape[1]))``. Each row
    must therefore hold a perfect-square number of entries. A constant row
    (zero range) is logged as all zeros instead of NaN. Nothing is logged when
    ``weights.shape[0] < 1``.

    Args:
        weights: tensor whose squeezed form is ``(num_weights, input_size)``;
            a single row ``(1, input_size)`` is also accepted.
        tag: tag prefix; the row index is appended.
        curT: global step passed as ``global_step``.
        writer: object providing ``add_image(tag, img, global_step=...)``.
            Defaults to the module-level ``logger``, which is not defined in this
            file and must be set elsewhere.
    """
    if weights.shape[0] < 1:
        return
    if writer is None:
        writer = logger
    weights = torch.squeeze(weights)
    if weights.dim() == 1:
        # A single row (1, D) squeezes to (D,); restore the row axis so that
        # shape[1] is the input size rather than an IndexError.
        weights = torch.unsqueeze(weights, 0)
    stim_size = int(np.sqrt(weights.shape[1]))
    for iw in range(weights.shape[0]):
        w = torch.squeeze(weights[iw, :])
        w_min = torch.min(w)
        span = torch.max(w) - w_min
        if span > 0:
            w = (w - w_min) / span
        else:
            # Constant row: min-max scaling would be 0/0 (all NaN); log zeros.
            w = torch.zeros_like(w)
        w = torch.reshape(w, (stim_size, stim_size))
        writer.add_image(tag + str(iw), torch.unsqueeze(w, 0), global_step=curT)
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also serialises NumPy arrays and numeric/bool scalars.

    Arrays become (nested) lists via ``ndarray.tolist()``. NumPy integer, floating
    and bool scalars (e.g. ``np.int64``, ``np.float32``, ``np.bool_``) become the
    matching Python value via ``.item()``. Any other object falls through to the
    base encoder, which raises ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # Scalars pulled out of arrays (e.g. arr[0], arr.sum()) are NumPy types
        # that the stdlib encoder rejects unless converted.
        if isinstance(obj, (np.integer, np.floating, np.bool_)):
            return obj.item()
        return super().default(obj)
def vis_trajectory(json_file, filename, **dataDict):
    """Plot trajectories as 3-D curves (time, width, height), save the figure, then show it.

    Args:
        json_file: path to a JSON-lines file. Each non-blank line is one
            trajectory object with keys ``'t'``, ``'x'``, ``'y'``, ``'color'``
            (RGB in 0-255; divided by 255 for matplotlib) and ``'id'``.
        filename: output path passed to ``savefig`` (dpi=500, transparent).
        **dataDict: must contain ``spike_w`` and ``spike_h``. They set the
            width/height axis limits, and ``x`` is plotted mirrored as
            ``spike_w - x``.
    """
    spike_h = dataDict.get('spike_h')
    spike_w = dataDict.get('spike_w')

    trajectories = []
    with open(json_file, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():  # tolerate blank lines instead of failing in json.loads
                trajectories.append(json.loads(line))

    fig = plt.figure(figsize=[10, 6])
    ax = fig.add_subplot(111, projection='3d')

    # Start from +/-inf so the time limits always come from the data.
    min_t, max_t = np.inf, -np.inf
    for traj in trajectories:
        t = np.array(traj['t'])
        if t.size:
            min_t = min(min_t, np.min(t))
            max_t = max(max_t, np.max(t))
        x = spike_w - np.array(traj['x'])
        y = np.array(traj['y'])
        color = np.array(traj['color']) / 255.
        ax.plot(t, x, y, color=color, linewidth=2, label='traj ' + str(traj['id']))

    ax.legend(loc='best', bbox_to_anchor=(0.7, 0., 0.4, 0.8))
    # Right-multiplying the projection by diag(zoom) rescales the projected
    # x/y/z inputs; it appears intended to stretch the time (x) axis.
    zoom = [2.2, 0.8, 0.5, 1]
    ax.get_proj = lambda: np.dot(Axes3D.get_proj(ax), np.diag(zoom))
    if min_t <= max_t:  # at least one trajectory had time samples
        ax.set_xlim(min_t, max_t)
    ax.set_ylim(0, spike_w)
    ax.set_zlim(0, spike_h)
    ax.set_xlabel('time', fontsize=15)
    ax.set_ylabel('width', fontsize=15)
    ax.set_zlabel('height', fontsize=15)
    ax.view_init(elev=16, azim=135)
    ax.yaxis.set_major_locator(MultipleLocator(100))
    fig.subplots_adjust(top=1., bottom=0., left=0.2, right=1.)
    # Save before show(): a blocking show() closes the figure once its window is
    # dismissed, and saving afterwards would write a new, empty figure.
    fig.savefig(filename, dpi=500, transparent=True)
    plt.show()