# -*- coding: utf-8 -*-
# @Time : 2024/12/05 20:17
# @Author : Yajing Zheng
# @Email: [email protected]
# @File : test_snntracker.py
"""Run the SNN multi-object tracker on one motVidarReal2020 scene.

Pipeline:

1. Load the spike stream of the scene selected by ``scene_idx``.
2. Calibrate the tracker on the first ``calibration_time`` entries of the
   spike matrix.
3. Track over the remaining entries.
4. Write the outputs into ``results/``:

   * ``<scene>_snn.txt``: result file handed to ``SNNTracker.get_results``.
   * ``<scene>_snn.avi``: tracking video written through ``cv2.VideoWriter``.
   * ``<scene>.png``: trajectory plot rendered by ``vis_trajectory`` from
     ``<scene>_py.json``. That JSON file is presumably produced by
     ``SNNTracker.save_trajectory``; verify against that method.
"""
import os
import sys
from pprint import pprint

import numpy as np
import torch
import cv2

# The project packages live one directory up, so the search path must be
# extended before any of the project imports below.
sys.path.append("..")

import path  # presumably the project-local path helper one level up
from spkData.load_dat import data_parameter_dict, SpikeStream
from spkProc.tracking.snn_tracker import SNNTracker
from utils import vis_trajectory
from visualization.get_video import obtain_mot_video, obtain_detection_video  # used by the optional section below
from tracking_mot import TrackingMetrics  # used by the optional section below

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Change the path to where you put the datasets. data_filename is resolved
# by data_parameter_dict.
test_scene = ['spike59', 'rotTrans', 'cplCam', 'cpl1', 'badminton', 'ball']
scene_idx = 2
data_name = test_scene[scene_idx]
data_filename = 'motVidarReal2020/' + data_name
label_type = 'tracking'
results_dir = 'results'
# Number of leading entries (axis 0, presumably time) of the spike matrix
# used only for calibration; tracking runs on the rest.
calibration_time = 150

# ---------------------------------------------------------------------------
# Load data
# ---------------------------------------------------------------------------
para_dict = data_parameter_dict(data_filename, label_type)
pprint(para_dict)
vidarSpikes = SpikeStream(**para_dict)
# The whole recording is loaded. For long sequences a prefix can be used:
# spikes = vidarSpikes.get_block_spikes(begin_idx=0, block_len=2000)
spikes = vidarSpikes.get_spike_matrix()
pprint(spikes.shape)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------------------------------------------------------------------------
# Output files
# ---------------------------------------------------------------------------
filename = path.split_path_into_pieces(data_filename)
result_filename = filename[-1] + '_snn.txt'
os.makedirs(results_dir, exist_ok=True)
tracking_file = os.path.join(results_dir, result_filename)
# Start from a clean result file. A stale one from an earlier run would
# otherwise remain (presumably get_results appends to it; confirm).
if os.path.exists(tracking_file):
    os.remove(tracking_file)

# ---------------------------------------------------------------------------
# Tracker
# ---------------------------------------------------------------------------
spike_h = para_dict.get('spike_h')
spike_w = para_dict.get('spike_w')
# Optional STP filter parameters (not passed anywhere in this script):
# stp_params = {'filterThr': 0.12,   # filter threshold
#               'voltageMin': -10,
#               'lifThr': 3}
spike_tracker = SNNTracker(spike_h, spike_w, device, attention_size=15)
spike_tracker.object_cluster.K2 = 4

# Calibrate on the leading entries. Per the original author's note, this
# uses the STP filter to filter out static spikes.
spike_tracker.calibrate_motion(spikes, calibration_time)

# Start tracking. Derive the video name from the extension only; a
# str.replace('txt', ...) would also rewrite any 'txt' inside the path.
track_videoName = os.path.splitext(tracking_file)[0] + '.avi'
# cv2.VideoWriter takes the frame size as (width, height).
mov = cv2.VideoWriter(track_videoName, cv2.VideoWriter_fourcc(*'MJPG'), 30, (spike_w, spike_h))
if not mov.isOpened():
    # VideoWriter silently discards frames when it failed to open.
    print('Warning: could not open video writer for ' + track_videoName, file=sys.stderr)
try:
    spike_tracker.get_results(spikes[calibration_time:], tracking_file, mov, save_video=True)
finally:
    # Release the writer even if tracking fails, so the AVI container is
    # finalized instead of being left half-written.
    mov.release()

# ---------------------------------------------------------------------------
# Trajectories
# ---------------------------------------------------------------------------
trajectories_filename = os.path.join(results_dir, data_name + '_py.json')
visTraj_filename = os.path.join(results_dir, data_name + '.png')
spike_tracker.save_trajectory(results_dir, data_name)
vis_trajectory(trajectories_filename, visTraj_filename, **para_dict)

cv2.destroyAllWindows()

# ---------------------------------------------------------------------------
# Optional evaluation (disabled)
# ---------------------------------------------------------------------------
# Measure the multi-object tracking performance:
# metrics = TrackingMetrics(tracking_file, **para_dict)
# metrics.get_results()
#
# Visualize the tracking results to a video (evaluate_seq_len must be
# defined first, e.g. from spikes.shape[0]):
# video_filename = os.path.join(results_dir, filename[-1] + '_mot.avi')
# obtain_mot_video(spike_tracker.filterd_spikes, video_filename, tracking_file, **para_dict)
# obtain_detection_video(spikes, video_filename, tracking_file, evaluate_seq_len=evaluate_seq_len, **para_dict)