# -*- coding: utf-8 -*-
# @Time : 2023/7/16 20:23
# @Author : Yajing Zheng
# @Email: [email protected]
# @File : snn_tracker.py
import os, sys
sys.path.append('../..')
import time
import numpy as np
import torch
from spkProc.filters.stp_filters_torch import STPFilter
from spkProc.detection.attention_select import SaccadeInput
from spkProc.motion.motion_detection import motion_estimation
from spkProc.detection.stdp_clustering import stdp_cluster
from utils import NumpyEncoder
from collections import namedtuple
import json
import cv2
from tqdm import tqdm
# Per-object trajectory record: track id, per-step x/y centers, timestamps, and a display color.
Trajectory = namedtuple('Trajectory', ['id', 'x', 'y', 't', 'color'])
class SNNTracker:
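    """Multi-object tracker operating directly on binary spike streams.

    The per-frame pipeline couples an STP filter, a saccade-style attention
    module (SaccadeInput) that proposes boxes, a motion estimator, and an
    STDP clustering stage that forms and updates tracks on a
    spike_h x spike_w sensor.
    """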
def __init__(self, spike_h, spike_w, device, attention_size=20, diff_time=1, **STPargs):
self.spike_h = spike_h
self.spike_w = spike_w
self.device = device
        # **STPargs collects optional keyword overrides into a dict; unpacking an
        # empty dict is a no-op, so no None check is needed.
        self.stp_filter = STPFilter(spike_h, spike_w, device, diff_time, **STPargs)
self.attention_size = attention_size
self.object_detection = SaccadeInput(spike_h, spike_w, box_size=self.attention_size, device=device)
        # Deferred import so tensorboardX is only required when a tracker is built.
        from tensorboardX import SummaryWriter
        logger = SummaryWriter(log_dir='data/log_pkuvidar')
self.motion_estimator = motion_estimation(spike_h, spike_w, device, logger=logger)
self.object_cluster = stdp_cluster(spike_h, spike_w, box_size=self.attention_size, device=device)
        self.calibration_time = 150  # default number of warm-up steps for calibrate_motion
        self.timestamps = 0          # global step counter, shared by calibration and tracking
        self.trajectories = {}       # track id -> Trajectory record
        self.filtered_spikes = []
def calibrate_motion(self, spikes, calibration_time=None):
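        """Warm up the STP filter dynamics on the first `calibration_time` frames.

        No detections are produced here; the pass only settles the filter state
        and advances the global timestamp counter.
        """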
if calibration_time is None:
calibration_time = self.calibration_time
else:
self.calibration_time = calibration_time
        print('begin calibration...')
for t in range(calibration_time):
input_spk = torch.from_numpy(spikes[t, :, :]).to(self.device)
self.stp_filter.update_dynamics(t, input_spk)
self.timestamps += 1
def get_results(self, spikes, res_filepath, mov_writer=None, save_video=False):
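        """Track objects through `spikes` and append MOT-style rows to `res_filepath`.

        Each timestep runs the pipeline: STP filtering -> attention-based box
        proposal -> motion estimation -> STDP clustering -> track update. When
        `save_video` is True, annotated frames are written via `mov_writer`.
        """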
        result_file = open(res_filepath, 'a+')
        timestamps = spikes.shape[0]
        start_time = time.time()  # wall-clock timing for the summary printed after the loop
        predict_kwargs = {'spike_h': self.spike_h, 'spike_w': self.spike_w, 'device': self.device}
        for t in tqdm(range(timestamps), desc=f'Saving tracking results to {res_filepath}'):
try:
                input_spk = torch.from_numpy(spikes[t, :, :]).to(self.device)
                # Stage 1: STP filtering of the raw spike frame.
                self.stp_filter.update_dynamics(self.timestamps, input_spk)
                self.stp_filter.local_connect(self.stp_filter.filter_spk)
                # Stage 2: attention-driven proposal of candidate object boxes.
                self.object_detection.update_dnf(self.stp_filter.lif_spk)
                attentionBox, attentionInput = self.object_detection.get_attention_location(self.stp_filter.lif_spk)
                num_box = attentionBox.shape[0]
                # Stage 3: motion estimation with a local winner-take-all readout.
                self.motion_estimator.stdp_tracking(self.stp_filter.lif_spk)
                motion_id, motion_vector, _ = self.motion_estimator.local_wta(self.stp_filter.lif_spk, self.timestamps, visualize=True)
                # Stage 4: STDP clustering turns attention windows into object candidates.
                predict_fire, sw, bw = self.object_cluster.update_weight(attentionInput)
                predict_object = self.object_cluster.detect_object(predict_fire, attentionBox, motion_id, motion_vector, **predict_kwargs)
                # Stage 5: update the track list and commit the adapted weights.
                sw, bw = self.object_cluster.update_tracks(predict_object, sw, bw, self.timestamps)
                self.object_cluster.synaptic_weight = sw.detach().clone()
                self.object_cluster.bias_weight = bw.detach().clone()
                # Collect raw detections as (y1, x1, y2, x2, 1, 1) rows; currently unused downstream.
                dets = torch.zeros((num_box, 6), dtype=torch.int)
                for i_box, bbox in enumerate(attentionBox):
                    dets[i_box, :] = torch.tensor([bbox[0], bbox[1], bbox[2], bbox[3], 1, 1])
                if save_video:
                    # Render the filtered spike frame as an 8-bit BGR image for annotation.
                    track_frame = self.stp_filter.lif_spk.cpu().numpy()
                    track_frame = (track_frame * 255).astype(np.uint8)
                    track_frame = cv2.cvtColor(track_frame, cv2.COLOR_GRAY2BGR)
                    # Draw the raw attention proposals in red.
                    for i_box in range(attentionBox.shape[0]):
                        tmp_box = attentionBox[i_box, :]
                        cv2.rectangle(track_frame, (int(tmp_box[1]), int(tmp_box[0])),
                                      (int(tmp_box[3]), int(tmp_box[2])), (0, 0, 255), 2)
                for i_box in range(self.object_cluster.K2):
                    if self.object_cluster.tracks[i_box].visible == 1:
                        tmp_box = self.object_cluster.tracks[i_box].bbox.numpy()
                        pred_box = self.object_cluster.tracks[i_box].predbox.numpy()
                        track_id = self.object_cluster.tracks[i_box].id
                        color = self.object_cluster.tracks[i_box].color
                        # Box centers and extents; boxes are stored as (y1, x1, y2, x2).
                        mid_y = (tmp_box[0, 0] + tmp_box[0, 2]) / 2  # height axis
                        mid_x = (tmp_box[0, 1] + tmp_box[0, 3]) / 2  # width axis
                        box_w = int(tmp_box[0, 3] - tmp_box[0, 1])
                        box_h = int(tmp_box[0, 2] - tmp_box[0, 0])
                        # One MOT-style row per visible track: frame,id,x,y,w,h,conf,-1,-1,-1.
                        print('%d,%d,%.2f,%.2f,%.2f,%.2f,1,-1,-1,-1' % (
                            self.timestamps, track_id, tmp_box[0, 1], tmp_box[0, 0], box_w, box_h), file=result_file)
                        # Update the trajectory, creating a record on first sight of this id.
                        if track_id not in self.trajectories:
                            self.trajectories[track_id] = Trajectory(int(track_id), [], [], [], 255 * np.random.rand(1, 3))
                        self.trajectories[track_id].x.append(mid_x)
                        self.trajectories[track_id].y.append(mid_y)
                        self.trajectories[track_id].t.append(self.timestamps)
                        if save_video:
                            # Draw the detected box in the track's color.
                            cv2.rectangle(track_frame, (int(tmp_box[0, 1]), int(tmp_box[0, 0])),
                                          (int(tmp_box[0, 3]), int(tmp_box[0, 2])),
                                          (int(color[0, 0]), int(color[0, 1]), int(color[0, 2])), 2)
                            # # Optionally draw the predicted box as well:
                            # cv2.rectangle(track_frame, (int(pred_box[0, 1]), int(pred_box[0, 0])),
                            #               (int(pred_box[0, 3]), int(pred_box[0, 2])), (0, 0, 255), 2)
                            # Filled label background above the box.
                            cv2.rectangle(track_frame, (int(tmp_box[0, 1]), int(tmp_box[0, 0] - 35)),
                                          (int(tmp_box[0, 1] + 60), int(tmp_box[0, 0])),
                                          (int(color[0, 0]), int(color[0, 1]), int(color[0, 2])), -1)
                            # Label coasting (unmatched) tracks as predictions.
                            if self.object_cluster.tracks[i_box].unvisible_count > 0:
                                show_text = 'predict' + str(track_id)
                            else:
                                show_text = 'object' + str(track_id)
                            cv2.putText(track_frame, show_text, (int(tmp_box[0, 1]), int(tmp_box[0, 0] - 10)),
                                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
                # Stamp the frame index and emit the annotated frame.
                if save_video:
                    cv2.putText(track_frame, str(int(self.timestamps)),
                                (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 255), 2)
                    mov_writer.write(track_frame)
                self.timestamps += 1
            except RuntimeError as exception:
                # Recover from CUDA out-of-memory by freeing cached blocks and skipping the frame.
                if "out of memory" in str(exception):
                    print('WARNING: out of memory')
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                else:
                    raise
        total_time = time.time() - start_time
        print('Total tracking took: %.3f seconds for %d timestamps of spikes' %
              (total_time, self.timestamps - self.calibration_time))
        result_file.close()
def save_trajectory(self, results_dir, data_name):
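        """Dump trajectories as JSON lines: the tracker's own records under
        `results_dir`, and the cluster's trajectories and track boxes under 'results/'.
        """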
        trajectories_filename = os.path.join(results_dir, data_name + '_py.json')
        mat_trajectories_filename = os.path.join('results', data_name + '.json')
        track_box_filename = os.path.join('results', data_name + '_bbox.json')
        # Remove stale outputs so the append-mode writes below start clean.
        for filename in (trajectories_filename, mat_trajectories_filename, track_box_filename):
            if os.path.exists(filename):
                os.remove(filename)
        for i_traj in range(self.object_cluster.K2):
            tmp_traj = self.object_cluster.trajectories[i_traj]
            tmp_bbox = self.object_cluster.tracks_bbox[i_traj]
            traj_json_string = json.dumps(tmp_traj._asdict(), cls=NumpyEncoder)
            bbox_json_string = json.dumps(tmp_bbox._asdict(), cls=NumpyEncoder)
            # One JSON object per line keeps the files parseable record by record.
            with open(mat_trajectories_filename, 'a+') as f:
                f.write(traj_json_string)
                f.write('\n')
            with open(track_box_filename, 'a+') as f:
                f.write(bbox_json_string)
                f.write('\n')
        for i_traj in self.trajectories:
            traj_json_string = json.dumps(self.trajectories[i_traj]._asdict(), cls=NumpyEncoder)
            with open(trajectories_filename, 'a+') as f:
                f.write(traj_json_string)
                f.write('\n')
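
if __name__ == '__main__':
    # Minimal usage sketch, not part of the original pipeline: random spikes stand in
    # for a real (T, H, W) uint8 spike array, and the output paths below are hypothetical.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    spike_h, spike_w, total_steps = 250, 400, 400
    spikes = (np.random.rand(total_steps, spike_h, spike_w) > 0.9).astype(np.uint8)
    os.makedirs('results', exist_ok=True)
    tracker = SNNTracker(spike_h, spike_w, device)
    tracker.calibrate_motion(spikes)                         # warm up on the first 150 frames
    tracker.get_results(spikes, 'results/demo_tracks.txt')   # append MOT-style rows
    tracker.save_trajectory('results', 'demo')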