# -*- coding: utf-8 -*-
# @Time   : 2023/7/16 20:23
# @Author : Yajing Zheng
# @Email  : yj.zheng@pku.edu.cn
# @File   : snn_tracker.py
import os, sys
sys.path.append('../..')
import time
import numpy as np
import torch
from spkProc.filters.stp_filters_torch import STPFilter
# from filters import stpFilter
from spkProc.detection.attention_select import SaccadeInput
from spkProc.motion.motion_detection import motion_estimation
from spkProc.detection.stdp_clustering import stdp_cluster
from utils import NumpyEncoder
from collections import namedtuple
import json
import cv2
from tqdm import tqdm

trajectories = namedtuple('trajectories', ['id', 'x', 'y', 't', 'color'])


class SNNTracker:
    """SNN-based multi-object tracker that operates directly on spike frames."""

    def __init__(self, spike_h, spike_w, device, attention_size=20, diff_time=1, **STPargs):
        self.spike_h = spike_h
        self.spike_w = spike_w
        self.device = device

        # self.stp_filter = STPFilter(spike_h, spike_w, device)
        # **STPargs is always a dict (possibly empty), so it can be forwarded directly
        self.stp_filter = STPFilter(spike_h, spike_w, device, diff_time, **STPargs)
        # self.stp_filter = stpFilter()

        self.attention_size = attention_size
        self.object_detection = SaccadeInput(spike_h, spike_w, box_size=self.attention_size, device=device)

        from tensorboardX import SummaryWriter
        logger = SummaryWriter(log_dir='data/log_pkuvidar')
        self.motion_estimator = motion_estimation(spike_h, spike_w, device, logger=logger)
        # gpu_tracker.track()  # run function between the code line where uses GPU
        self.object_cluster = stdp_cluster(spike_h, spike_w, box_size=self.attention_size, device=device)

        # self.timestamps = spikes.shape[0]
        # self.filterd_spikes = np.zeros([self.timestamps, self.spike_h, self.spike_w], np.uint8)
        self.calibration_time = 150
        self.timestamps = 0
        self.trajectories = {}
        self.filterd_spikes = []

    def calibrate_motion(self, spikes, calibration_time=None):
        """Warm up the STP filter dynamics on the first `calibration_time` spike frames."""
        if calibration_time is None:
            calibration_time = self.calibration_time
        else:
            self.calibration_time = calibration_time

        print('begin calibration...')
        for t in range(calibration_time):
            input_spk = torch.from_numpy(spikes[t, :, :]).to(self.device)
            self.stp_filter.update_dynamics(t, input_spk)
            self.timestamps += 1

    def get_results(self, spikes, res_filepath, mov_writer=None, save_video=False):
        """Track objects through `spikes`, appending one MOT-style line per visible track
        to `res_filepath` and optionally rendering boxes into `mov_writer`."""
        result_file = open(res_filepath, 'a+')
        timestamps = spikes.shape[0]
        total_time = 0
        predict_kwargs = {'spike_h': self.spike_h, 'spike_w': self.spike_w, 'device': self.device}

        start_time = time.time()
        for t in tqdm(range(timestamps), desc=f'Saving tracking results to {res_filepath}'):
            try:
                input_spk = torch.from_numpy(spikes[t, :, :]).to(self.device)
                self.stp_filter.update_dynamics(self.timestamps, input_spk)
                self.stp_filter.local_connect(self.stp_filter.filter_spk)
                # self.filterd_spikes[t, :, :] = self.stp_filter.lif_spk.cpu().detach().numpy()

                self.object_detection.update_dnf(self.stp_filter.lif_spk)
                attentionBox, attentionInput = self.object_detection.get_attention_location(self.stp_filter.lif_spk)
                # attentionInput = attentionInput.to(self.device)
                num_box = attentionBox.shape[0]

                self.motion_estimator.stdp_tracking(self.stp_filter.lif_spk)
                motion_id, motion_vector, _ = self.motion_estimator.local_wta(self.stp_filter.lif_spk,
                                                                              self.timestamps, visualize=True)
                # gpu_tracker.track()  # run function between the code line where uses GPU
                predict_fire, sw, bw = self.object_cluster.update_weight(attentionInput)
                predict_object = self.object_cluster.detect_object(predict_fire, attentionBox, motion_id,
                                                                   motion_vector, **predict_kwargs)
                # visualize_weights(sw, 'before update tracks', t)
                sw, bw = self.object_cluster.update_tracks(predict_object, sw, bw, self.timestamps)
                self.object_cluster.synaptic_weight = sw.detach().clone()
                self.object_cluster.bias_weight = bw.detach().clone()

                # assemble the raw detections (not used further in this method)
                dets = torch.zeros((num_box, 6), dtype=torch.int)
                for i_box, bbox in enumerate(attentionBox):
                    dets[i_box, :] = torch.tensor([bbox[0], bbox[1], bbox[2], bbox[3], 1, 1])
                track_ids = []

                if save_video:
                    track_frame = self.stp_filter.lif_spk.cpu().numpy()
                    track_frame = (track_frame * 255).astype(np.uint8)
                    # track_frame = np.transpose(track_frame, (1, 2, 0))
                    # track_frame = np.tile(track_frame, (3, 1, 1))
                    # track_frame = np.squeeze(track_frame)
                    track_frame = cv2.cvtColor(track_frame, cv2.COLOR_GRAY2BGR)

                    for i_box in range(attentionBox.shape[0]):
                        tmp_box = attentionBox[i_box, :]
                        cv2.rectangle(track_frame, (int(tmp_box[1]), int(tmp_box[0])),
                                      (int(tmp_box[3]), int(tmp_box[2])), (0, 0, 255), 2)

                for i_box in range(self.object_cluster.K2):
                    if self.object_cluster.tracks[i_box].visible == 1:
                        tmp_box = self.object_cluster.tracks[i_box].bbox.numpy()
                        pred_box = self.object_cluster.tracks[i_box].predbox.numpy()
                        id = self.object_cluster.tracks[i_box].id
                        color = self.object_cluster.tracks[i_box].color

                        # update the trajectories
                        mid_y = (tmp_box[0, 0] + tmp_box[0, 2]) / 2  # height
                        mid_x = (tmp_box[0, 1] + tmp_box[0, 3]) / 2  # width
                        box_w = int(tmp_box[0, 3] - tmp_box[0, 1])
                        box_h = int(tmp_box[0, 2] - tmp_box[0, 0])
                        print('%d,%d,%.2f,%.2f,%.2f,%.2f,1,-1,-1,-1' % (
                            self.timestamps, id, tmp_box[0, 1], tmp_box[0, 0], box_w, box_h), file=result_file)

                        if id not in self.trajectories:
                            self.trajectories[id] = trajectories(int(id), [], [], [], 255 * np.random.rand(1, 3))
                        self.trajectories[id].x.append(mid_x)
                        self.trajectories[id].y.append(mid_y)
                        self.trajectories[id].t.append(self.timestamps)

                        # the detection results
                        if save_video:
                            cv2.rectangle(track_frame, (int(tmp_box[0, 1]), int(tmp_box[0, 0])),
                                          (int(tmp_box[0, 3]), int(tmp_box[0, 2])),
                                          (int(color[0, 0]), int(color[0, 1]), int(color[0, 2])), 2)
                            # # the predicted results
                            # cv2.rectangle(track_frame, (int(pred_box[0, 1]), int(pred_box[0, 0])),
                            #               (int(pred_box[0, 3]), int(pred_box[0, 2])), (0, 0, 255), 2)
                            # the label box
                            cv2.rectangle(track_frame, (int(tmp_box[0, 1]), int(tmp_box[0, 0] - 35)),
                                          (int(tmp_box[0, 1] + 60), int(tmp_box[0, 0])),
                                          (int(color[0, 0]), int(color[0, 1]), int(color[0, 2])), -1)
                            if self.object_cluster.tracks[i_box].unvisible_count > 0:
                                show_text = 'predict' + str(id)
                            else:
                                show_text = 'object' + str(id)
                            cv2.putText(track_frame, show_text, (int(tmp_box[0, 1]), int(tmp_box[0, 0] - 10)),
                                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

                if save_video:
                    cv2.putText(track_frame, str(int(self.timestamps)), (10, 70),
                                cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 255), 2)
                    mov_writer.write(track_frame)

                self.timestamps += 1

            except RuntimeError as exception:
                if "out of memory" in str(exception):
                    print('WARNING: out of memory')
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                else:
                    raise exception

        total_time = time.time() - start_time
        print('Total tracking took: %.3f seconds for %d timestamps of spikes' %
              (total_time, self.timestamps - self.calibration_time))
        # if save_video:
        #     mov_writer.release()
        #     cv2.destroyAllWindows()
        result_file.close()

    def save_trajectory(self, results_dir, data_name):
        """Dump the accumulated trajectories and track bounding boxes to JSON files."""
        trajectories_filename = os.path.join(results_dir, data_name + '_py.json')
        mat_trajectories_filename = 'results/' + data_name + '.json'
        track_box_filename = 'results/' + data_name + '_bbox.json'

        if os.path.exists(trajectories_filename):
            os.remove(trajectories_filename)
        if os.path.exists(mat_trajectories_filename):
            os.remove(mat_trajectories_filename)
        if os.path.exists(track_box_filename):
            os.remove(track_box_filename)

        # trajectories and boxes maintained inside the STDP clustering module
        for i_traj in range(self.object_cluster.K2):
            tmp_traj = self.object_cluster.trajectories[i_traj]
            tmp_bbox = self.object_cluster.tracks_bbox[i_traj]
            traj_json_string = json.dumps(tmp_traj._asdict(), cls=NumpyEncoder)
            bbox_json_string = json.dumps(tmp_bbox._asdict(), cls=NumpyEncoder)

            with open(mat_trajectories_filename, 'a+') as f:
                f.write(traj_json_string)

            with open(track_box_filename, 'a+') as f:
                f.write(bbox_json_string)

        # trajectories accumulated by this tracker, one JSON object per line
        for i_traj in self.trajectories:
            traj_json_string = json.dumps(self.trajectories[i_traj]._asdict(), cls=NumpyEncoder)
            with open(trajectories_filename, 'a+') as f:
                f.write(traj_json_string)
                f.write('\n')
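

# ---------------------------------------------------------------------------
# Minimal usage sketch. The spike file, result paths, and data name below are
# illustrative assumptions rather than fixed by this module; `spikes` is
# expected to be a numpy array of binary spike frames of shape (T, spike_h,
# spike_w), with T at least as long as the calibration window.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    spikes = np.load('data/spikes.npy')  # hypothetical spike sequence, shape (T, H, W)

    tracker = SNNTracker(spike_h=spikes.shape[1], spike_w=spikes.shape[2], device=device)
    tracker.calibrate_motion(spikes)  # settle the STP filter on the first frames

    os.makedirs('results', exist_ok=True)
    tracker.get_results(spikes, 'results/tracking.txt')  # append MOT-style rows per visible track
    tracker.save_trajectory('results', 'demo')  # dump accumulated trajectories to JSON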