File size: 2,858 Bytes
9fa5305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# -*- coding: utf-8 -*-
# @Time : 2024/12/05 20:17
# @Author : Yajing Zheng
# @Email: [email protected]
# @File : test_snntracker.py
import os, sys
sys.path.append("..")
import path

import numpy as np
from spkData.load_dat import data_parameter_dict, SpikeStream
from pprint import pprint
import torch
from spkProc.tracking.snn_tracker import SNNTracker
from utils import vis_trajectory
from visualization.get_video import obtain_mot_video
import cv2
# from tracking_mot import TrackingMetrics

from visualization.get_video import obtain_detection_video

# Run the SNN tracker on one spike-camera scene: load spikes, calibrate the
# tracker, track the remaining frames (writing a .txt result file and an .avi
# video under results/), then save and plot the trajectories.

# Scene to process; change this to the dataset path/name you want to track.
test_scene = "0"
# data_filename = 'motVidarReal2020/rotTrans'
data_filename = test_scene
label_type = 'tracking'
para_dict = data_parameter_dict(data_filename, label_type)
pprint(para_dict)
vidarSpikes = SpikeStream(**para_dict)

# Load the whole recording. To process only a prefix instead, use e.g.
#   spikes = vidarSpikes.get_block_spikes(begin_idx=0, block_len=2000)
spikes = vidarSpikes.get_spike_matrix()
pprint(spikes.shape)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Leading frames handed to calibrate_motion; tracking then runs on
# spikes[calibration_time:].
calibration_time = 150
filename = path.split_path_into_pieces(data_filename)
result_filename = filename[-1] + '_snn.txt'
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs('results', exist_ok=True)
tracking_file = os.path.join('results', result_filename)
# Start from a clean results file; a missing file is not an error.
try:
    os.remove(tracking_file)
except FileNotFoundError:
    pass

spike_tracker = SNNTracker(para_dict.get('spike_h'), para_dict.get('spike_w'), device, attention_size=15)
# NOTE(review): K2 is a parameter of the tracker's object_cluster stage; its
# meaning is defined in spkProc.tracking, not here — confirm before tuning.
spike_tracker.object_cluster.K2 = 4

# Calibrate on the first `calibration_time` frames (per the original note, the
# STP filter is used to filter out static spikes).
spike_tracker.calibrate_motion(spikes, calibration_time)

# Start tracking. Swap only the extension: str.replace('txt', 'avi') would also
# rewrite any 'txt' appearing elsewhere in the path.
track_videoName = os.path.splitext(tracking_file)[0] + '.avi'
frame_size = (para_dict.get('spike_w'), para_dict.get('spike_h'))  # cv2 expects (width, height)
mov = cv2.VideoWriter(track_videoName, cv2.VideoWriter_fourcc(*'MJPG'), 30, frame_size)
if not mov.isOpened():
    # Best-effort: tracking results are still written to tracking_file.
    print('Warning: could not open video writer for %s; no tracking video will be written.'
          % track_videoName)
try:
    spike_tracker.get_results(spikes[calibration_time:], tracking_file, mov, save_video=True)
finally:
    # Release even if tracking fails so the video file is closed/finalized
    # and the writer is not leaked.
    mov.release()

# NOTE(review): unlike result_filename (which uses filename[-1]), these names use
# the full test_scene, so a scene like 'a/b' would place outputs under results/a/.
data_name = test_scene
trajectories_filename = os.path.join('results', data_name + '_py.json')
visTraj_filename = os.path.join('results', data_name + '.png')

# NOTE(review): trajectories_filename assumes save_trajectory writes
# '<dir>/<data_name>_py.json' — keep the two in sync.
spike_tracker.save_trajectory('results', data_name)
vis_trajectory(trajectories_filename, visTraj_filename, **para_dict)

# Optional: measure multi-object tracking performance
# (requires `from tracking_mot import TrackingMetrics`).
# metrics = TrackingMetrics(tracking_file, **para_dict)
# metrics.get_results()

cv2.destroyAllWindows()

# Optional: visualize the tracking results as a video
# (obtain_detection_video additionally needs evaluate_seq_len to be defined).
# video_filename = os.path.join('results', filename[-1] + '_mot.avi')
# obtain_mot_video(spike_tracker.filterd_spikes, video_filename, tracking_file, **para_dict)
# obtain_detection_video(spikes, video_filename, tracking_file, evaluate_seq_len=evaluate_seq_len, **para_dict)