# -*- coding: utf-8 -*-
# @Time : 2024/12/05 20:17
# @Author : Yajing Zheng
# @Email: yj.zheng@pku.edu.cn
# @File : test_snntracker.py
"""Run SNNTracker on one spike-stream scene and write its outputs to ``results/``.

Steps performed by this script:

1. Load the spike matrix of ``test_scene`` through ``SpikeStream``.
2. Calibrate the tracker on the first ``calibration_time`` entries of the
   spike matrix.
3. Track on the remaining entries, passing a results text file and an MJPG
   video writer to ``SNNTracker.get_results``.
4. Save the trajectories and render them with ``vis_trajectory``.
"""
import os
import sys

# Add the parent directory (relative to the working directory) to the import
# search path; this must happen before the imports below.
sys.path.append("..")

import path
import numpy as np
from spkData.load_dat import data_parameter_dict, SpikeStream
from pprint import pprint
import torch
from spkProc.tracking.snn_tracker import SNNTracker
from utils import vis_trajectory
from visualization.get_video import obtain_mot_video
import cv2
# from tracking_mot import TrackingMetrics
from visualization.get_video import obtain_detection_video

# change the path to where you put the datasets
test_scene = "0"
# data_filename = 'motVidarReal2020/rotTrans'
data_filename = test_scene
label_type = 'tracking'
para_dict = data_parameter_dict(data_filename, label_type)
pprint(para_dict)

vidarSpikes = SpikeStream(**para_dict)
# block_len = 2000
# spikes = vidarSpikes.get_block_spikes(begin_idx=0, block_len=block_len)
spikes = vidarSpikes.get_spike_matrix()
pprint(spikes.shape)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Number of leading entries of `spikes` handed to calibrate_motion; tracking
# then runs on spikes[calibration_time:].
calibration_time = 150

results_dir = 'results'
filename = path.split_path_into_pieces(data_filename)
result_filename = filename[-1] + '_snn.txt'
# exist_ok avoids the check-then-create race of exists() + makedirs().
os.makedirs(results_dir, exist_ok=True)

tracking_file = os.path.join(results_dir, result_filename)
# Start from a clean results file; a missing file is not an error.
try:
    os.remove(tracking_file)
except FileNotFoundError:
    pass

# stp_params = {'filterThr': 0.12,  # filter threshold
#               'voltageMin': -10,
#               'lifThr': 3}
spike_h = para_dict.get('spike_h')
spike_w = para_dict.get('spike_w')
spike_tracker = SNNTracker(spike_h, spike_w, device, attention_size=15)
spike_tracker.object_cluster.K2 = 4
# total_spikes = spikes

# using stp filter to filter out static spikes
spike_tracker.calibrate_motion(spikes, calibration_time)

# start tracking
# Swap only the file extension: str.replace('txt', 'avi') would also rewrite
# any other 'txt' substring occurring in the path.
track_videoName = os.path.splitext(tracking_file)[0] + '.avi'
mov = cv2.VideoWriter(track_videoName, cv2.VideoWriter_fourcc(*'MJPG'), 30,
                      (spike_w, spike_h))

data_name = test_scene
trajectories_filename = os.path.join(results_dir, data_name + '_py.json')
visTraj_filename = os.path.join(results_dir, data_name + '.png')

try:
    spike_tracker.get_results(spikes[calibration_time:], tracking_file, mov,
                              save_video=True)

    spike_tracker.save_trajectory(results_dir, data_name)
    vis_trajectory(trajectories_filename, visTraj_filename, **para_dict)

    # measure the multi-object tracking performance
    # metrics = TrackingMetrics(tracking_file, **para_dict)
    # metrics.get_results()
    #
    # block_len = total_spikes.shape[0]
finally:
    # Always release the writer, even if tracking or visualisation raised,
    # so the video file handle is not leaked.
    mov.release()
    cv2.destroyAllWindows()

# # visualize the tracking results to a video
# video_filename = os.path.join('results', filename[-1] + '_mot.avi')
# obtain_mot_video(spike_tracker.filterd_spikes, video_filename, tracking_file, **para_dict)
# obtain_detection_video(total_spikes, video_filename, tracking_file, evaluate_seq_len=evaluate_seq_len, **para_dict)