# Repository page residue (kept as a comment so the file parses):
#   path:   test_embedding_shape/snnTracker/test_snntracker.py
#   commit: 9fa5305 (verified) by zzzzzeee, "Upload 28 files"
# -*- coding: utf-8 -*-
# @Time : 2024/12/05 20:17
# @Author : Yajing Zheng
# @Email: [email protected]
# @File : test_snntracker.py
import os, sys
sys.path.append("..")
import path
import numpy as np
from spkData.load_dat import data_parameter_dict, SpikeStream
from pprint import pprint
import torch
from spkProc.tracking.snn_tracker import SNNTracker
from utils import vis_trajectory
from visualization.get_video import obtain_mot_video
import cv2
from tracking_mot import TrackingMetrics
from visualization.get_video import obtain_detection_video
# ---------------------------------------------------------------------------
# Dataset selection and spike loading.
# Change the dataset prefix below to wherever the datasets are stored.
# ---------------------------------------------------------------------------
test_scene = ['spike59', 'rotTrans', 'cplCam', 'cpl1', 'badminton', 'ball']
scene_idx = 2  # index into test_scene (2 -> 'cplCam')
data_filename = f'motVidarReal2020/{test_scene[scene_idx]}'
label_type = 'tracking'

# Loader configuration for the chosen scene; printed for inspection.
para_dict = data_parameter_dict(data_filename, label_type)
pprint(para_dict)

# Build the spike stream from that configuration and fetch its spike matrix.
# To load only a fixed-length block instead, the alternative is:
#   spikes = vidarSpikes.get_block_spikes(begin_idx=0, block_len=2000)
vidarSpikes = SpikeStream(**para_dict)
spikes = vidarSpikes.get_spike_matrix()
pprint(spikes.shape)
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Number of leading entries along the first axis of `spikes` (presumably
# time steps) used for motion calibration; tracking runs on the rest.
calibration_time = 150

# Tracking results are written to results/<last path piece>_snn.txt, where
# the last piece of data_filename is presumably the scene name.
filename = path.split_path_into_pieces(data_filename)
result_filename = filename[-1] + '_snn.txt'
# exist_ok=True replaces the exists()/makedirs() check-then-act pair, which
# was also missing its indented body (IndentationError).
os.makedirs('results', exist_ok=True)
tracking_file = os.path.join('results', result_filename)
# Start from a clean result file; a leftover file from an earlier run is
# removed (presumably because the tracker appends to it). A missing file is
# the normal case and is ignored; any other OSError still propagates.
try:
    os.remove(tracking_file)
except FileNotFoundError:
    pass
# Optional STP filter parameters (currently unused):
# stp_params = {'filterThr': 0.12,  # filter threshold
#               'voltageMin': -10,
#               'lifThr': 3}
spike_tracker = SNNTracker(para_dict.get('spike_h'), para_dict.get('spike_w'), device, attention_size=15)
spike_tracker.object_cluster.K2 = 4
# Calibrate on the leading `calibration_time` steps (per the original note,
# the STP filter is used to filter out static spikes).
spike_tracker.calibrate_motion(spikes, calibration_time)

# Start tracking. The video is written next to the text results with an
# .avi extension; splitext swaps only the extension, whereas
# str.replace('txt', 'avi') would also rewrite 'txt' elsewhere in the path.
track_videoName = os.path.splitext(tracking_file)[0] + '.avi'
# cv2.VideoWriter takes the frame size as (width, height).
mov = cv2.VideoWriter(track_videoName, cv2.VideoWriter_fourcc(*'MJPG'), 30,
                      (para_dict.get('spike_w'), para_dict.get('spike_h')))
if not mov.isOpened():
    # OpenCV silently drops frames written to an unopened writer.
    print('Warning: could not open video writer for ' + track_videoName)
try:
    spike_tracker.get_results(spikes[calibration_time:], tracking_file, mov, save_video=True)
finally:
    # Release even if tracking fails, so the writer is not leaked.
    mov.release()

# Export per-object trajectories and render them to an image.
# NOTE(review): assumes save_trajectory('results', data_name) writes
# results/<data_name>_py.json -- confirm against SNNTracker.
data_name = test_scene[scene_idx]
trajectories_filename = os.path.join('results', data_name + '_py.json')
visTraj_filename = os.path.join('results', data_name + '.png')
spike_tracker.save_trajectory('results', data_name)
vis_trajectory(trajectories_filename, visTraj_filename, **para_dict)

# Measure the multi-object tracking performance (disabled):
# metrics = TrackingMetrics(tracking_file, **para_dict)
# metrics.get_results()
cv2.destroyAllWindows()

# Visualize the tracking results to a video (disabled). Note that
# `total_spikes` and `evaluate_seq_len` are not defined in this script;
# define them before re-enabling these lines.
# block_len = total_spikes.shape[0]
# video_filename = os.path.join('results', filename[-1] + '_mot.avi')
# obtain_mot_video(spike_tracker.filterd_spikes, video_filename, tracking_file, **para_dict)
# obtain_detection_video(total_spikes, video_filename, tracking_file, evaluate_seq_len=evaluate_seq_len, **para_dict)