Spaces:
Running
Running
Upload 28 files
Browse files- .gitattributes +12 -0
- snnTracker/0.dat +3 -0
- snnTracker/1.dat +3 -0
- snnTracker/2.dat +3 -0
- snnTracker/3.dat +3 -0
- snnTracker/4.dat +3 -0
- snnTracker/5.dat +3 -0
- snnTracker/6.dat +3 -0
- snnTracker/7.dat +3 -0
- snnTracker/8.dat +3 -0
- snnTracker/9.dat +3 -0
- snnTracker/datasets/0/config.yaml +13 -0
- snnTracker/driving_0_snntracker.avi +3 -0
- snnTracker/driving_0_tfi.avi +3 -0
- snnTracker/path.py +52 -0
- snnTracker/spkData/load_dat.py +203 -0
- snnTracker/spkProc/detection/attention_select.py +132 -0
- snnTracker/spkProc/detection/motion_clustering.py +101 -0
- snnTracker/spkProc/detection/stdp_clustering.py +397 -0
- snnTracker/spkProc/filters/stp_filters_torch.py +170 -0
- snnTracker/spkProc/motion/motion_detection.py +347 -0
- snnTracker/spkProc/tracking/snn_tracker.py +227 -0
- snnTracker/test_motion_detection.py +201 -0
- snnTracker/test_snntracker copy.py +77 -0
- snnTracker/test_snntracker.py +78 -0
- snnTracker/utils.py +207 -0
- snnTracker/visualization/get_image.py +121 -0
- snnTracker/visualization/get_video.py +245 -0
- snnTracker/visualization/optical_flow_visualization.py +272 -0
.gitattributes
CHANGED
@@ -33,3 +33,15 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
snnTracker/0.dat filter=lfs diff=lfs merge=lfs -text
|
37 |
+
snnTracker/1.dat filter=lfs diff=lfs merge=lfs -text
|
38 |
+
snnTracker/2.dat filter=lfs diff=lfs merge=lfs -text
|
39 |
+
snnTracker/3.dat filter=lfs diff=lfs merge=lfs -text
|
40 |
+
snnTracker/4.dat filter=lfs diff=lfs merge=lfs -text
|
41 |
+
snnTracker/5.dat filter=lfs diff=lfs merge=lfs -text
|
42 |
+
snnTracker/6.dat filter=lfs diff=lfs merge=lfs -text
|
43 |
+
snnTracker/7.dat filter=lfs diff=lfs merge=lfs -text
|
44 |
+
snnTracker/8.dat filter=lfs diff=lfs merge=lfs -text
|
45 |
+
snnTracker/9.dat filter=lfs diff=lfs merge=lfs -text
|
46 |
+
snnTracker/driving_0_snntracker.avi filter=lfs diff=lfs merge=lfs -text
|
47 |
+
snnTracker/driving_0_tfi.avi filter=lfs diff=lfs merge=lfs -text
|
snnTracker/0.dat
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1d325c5ecb53e64a6429c753df5e0c3b46deaf3bce97f445499c9dfb64b3db4d
|
3 |
+
size 5000000
|
snnTracker/1.dat
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ba866d8427b0baa97a0c09bab6055edceaac2dfff6fcba0580f21973628e5d1a
|
3 |
+
size 5000000
|
snnTracker/2.dat
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8a00b47dc8e9e84f902f159cd1f0ab387440091b470e4da9503e18e3bae4d354
|
3 |
+
size 5000000
|
snnTracker/3.dat
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:47a782e5eea659f9ddd8ad0e20b688e1e69ed9bf589f99989586397d9e2bd825
|
3 |
+
size 5000000
|
snnTracker/4.dat
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:02b795f9ca1f3affc59d59e280051bfd50f24c3d9c1c214ab8a680c9fdc2175a
|
3 |
+
size 5000000
|
snnTracker/5.dat
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:68ebdd547fc9c36401a3033701f725d9552de1ae235ea49f9bf04a1afad4a331
|
3 |
+
size 5000000
|
snnTracker/6.dat
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e5d0349fab1c1da0d3b3559c8bf30795f2488ec7caa182126525c5eb49f52d10
|
3 |
+
size 5000000
|
snnTracker/7.dat
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a36642ec9cc8ef292cb1223d8cf4c263dde4203da83f0e55179e3570a46754b3
|
3 |
+
size 5000000
|
snnTracker/8.dat
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cea9cb59395d668eab1c18b28cca3821e63f6ed2f097e5a120e67253ee78080a
|
3 |
+
size 5000000
|
snnTracker/9.dat
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:50ef6ee3758465fba69b5c453e76cae4a7c805805e646445e8877cc886849ce6
|
3 |
+
size 5000000
|
snnTracker/datasets/0/config.yaml
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# 基本配置
|
3 |
+
spike_h: 250 # 脉冲数据高度
|
4 |
+
spike_w: 400 # 脉冲数据宽度
|
5 |
+
is_labeled: false # 是否有标注数据
|
6 |
+
|
7 |
+
# 数据标识符
|
8 |
+
data_field_identifier: '' # 数据文件标识符
|
9 |
+
label_field_identifier: '' # 标注文件标识符
|
10 |
+
|
11 |
+
# 标注数据配置
|
12 |
+
labeled_data_type: 'tracking' # 标注数据类型
|
13 |
+
labeled_data_suffix: 'txt' # 标注文件后缀
|
snnTracker/driving_0_snntracker.avi
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f738eb42502b754cadf26539f749fe0281ba023cb139f32dff7af2a0a0cadf99
|
3 |
+
size 578258
|
snnTracker/driving_0_tfi.avi
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:62d6381f1f002c2f68b0e935d69db5ca2545ceb68566c6e6706f613cba6d2af4
|
3 |
+
size 9748482
|
snnTracker/path.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Time : 2023/7/16 20:19
|
3 |
+
# @Author : Yajing Zheng
|
4 |
+
# @Email: [email protected]
|
5 |
+
# @File : path.py
|
6 |
+
# here put the import lib
|
7 |
+
import os
|
8 |
+
|
9 |
+
|
10 |
+
def seek_file(search_dirs, filename):
    """Search for ``filename`` under ``search_dirs``, retrying one parent
    directory up at a time until a match is found.

    :param search_dirs: directory path whose tree (and ancestors' trees)
        is walked
    :param filename: bare file name to look for
    :return: full path of the first match, or ``None`` when the file is not
        found at any level (previously this fell off the end implicitly)
    """
    search_dir_split = split_path_into_pieces(search_dirs)
    dir_level = len(search_dir_split)

    for i_dir in range(0, dir_level):
        if i_dir > 0:
            # drop the deepest remaining component and retry one level up
            search_dir_split.pop(-1)
        search_dir = os.path.join(*search_dir_split)
        for root, dirs, files in os.walk(search_dir):
            if filename in files:
                print('{0}/{1}'.format(root, filename))
                filepath = os.path.join(root, filename)
                return filepath

    # explicit miss instead of an implicit fall-through
    return None
|
24 |
+
|
25 |
+
|
26 |
+
def split_path_into_pieces(path: str):
    """Split a path string into its individual components.

    Works for relative and absolute paths; an absolute path keeps its root
    (e.g. ``'/'``) as the first piece.

    :param path: path string; a trailing ``'/'`` is ignored
    :return: list of path components (empty list for an empty string,
        which previously raised ``IndexError`` on ``path[-1]``)
    """
    if not path:
        return []

    pieces = []
    if path[-1] == '/':
        # drop a trailing separator so the last component is not empty
        path = path[0:-1]

    while True:
        splits = os.path.split(path)
        if splits[0] == '':
            # relative path fully consumed: the head piece remains
            pieces.insert(0, splits[-1])
            break
        if splits[-1] == '':
            # reached the filesystem root (e.g. '/'): keep it as a piece
            pieces.insert(0, splits[0])
            break
        pieces.insert(0, splits[-1])
        path = splits[0]

    return pieces
|
43 |
+
|
44 |
+
def replace_identifier(path: list, src: str, dst: str):
    """Return a copy of ``path`` with every piece equal to ``src`` replaced
    by ``dst``; all other pieces are kept unchanged."""
    return [dst if piece == src else piece for piece in path]
|
snnTracker/spkData/load_dat.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Time : 2023/7/16 20:13
|
3 |
+
# @Author : Yajing Zheng
|
4 |
+
# @Email: [email protected]
|
5 |
+
# @File : load_dat.py
|
6 |
+
import os, sys
|
7 |
+
import warnings
|
8 |
+
import glob
|
9 |
+
import yaml
|
10 |
+
import numpy as np
|
11 |
+
import path
|
12 |
+
|
13 |
+
# key-value for generate data loader according to the type of label data
|
14 |
+
LABEL_DATA_TYPE = {
|
15 |
+
'raw': 0,
|
16 |
+
'reconstruction': 1,
|
17 |
+
'optical_flow': 2,
|
18 |
+
'mono_depth_estimation': 3.1,
|
19 |
+
'stero_depth_estimation': 3.2,
|
20 |
+
'detection': 4,
|
21 |
+
'tracking': 5,
|
22 |
+
'recognition': 6
|
23 |
+
}
|
24 |
+
|
25 |
+
|
26 |
+
# generate parameters dictionary according to labeled or not
def data_parameter_dict(data_filename, label_type):
    """Build the parameter dictionary for loading a spike ``.dat`` recording.

    Locates the dataset's ``config.yaml`` (next to the data for absolute
    paths, under ``datasets/`` otherwise), reads the spatial resolution and
    label settings, and collects the data / label file lists.

    :param data_filename: absolute path or dataset-relative name of the data
    :param label_type: task name; expected to be a key of ``LABEL_DATA_TYPE``
    :return: dict with at least ``spike_h``, ``spike_w``, ``filelist`` and
        ``filepath``; label-related entries are added for labeled datasets
    """
    filename = path.split_path_into_pieces(data_filename)

    if os.path.isabs(data_filename):
        file_root = data_filename
        if os.path.isdir(file_root):
            search_root = file_root
        else:
            # parent directory of the data file; os.path.join is portable,
            # unlike the previous hard-coded '\\' (Windows-only) separator
            search_root = os.path.join(*filename[0:-1])
        config_filename = path.seek_file(search_root, 'config.yaml')
    else:
        file_root = os.path.join('', 'datasets', *filename)
        config_filename = os.path.join('', 'datasets', filename[0], 'config.yaml')

    try:
        with open(config_filename, 'r', encoding='utf-8') as fin:
            configs = yaml.load(fin, Loader=yaml.FullLoader)
    except TypeError as err:
        # seek_file returned None -> open(None) raises TypeError
        print("Cannot find config file" + str(err))
        raise err

    except KeyError as exception:
        print('ERROR! Task name does not exist')
        print('Task name must be in %s' % LABEL_DATA_TYPE.keys())
        raise exception

    is_labeled = configs.get('is_labeled')

    paraDict = {'spike_h': configs.get('spike_h'), 'spike_w': configs.get('spike_w')}
    paraDict['filelist'] = None

    if is_labeled:
        paraDict['labeled_data_type'] = configs.get('labeled_data_type')
        paraDict['labeled_data_suffix'] = configs.get('labeled_data_suffix')
        paraDict['label_root_list'] = None

        if os.path.isdir(file_root):
            # directory of .dat files: keep them ordered by modification time
            filelist = sorted(glob.glob(file_root + '/*.dat'), key=os.path.getmtime)
            filepath = filelist[0]

            labelname = path.replace_identifier(filename, configs.get('data_field_identifier', ''),
                                                configs.get('label_field_identifier', ''))
            label_root_list = os.path.join('', 'datasets', *labelname)
            paraDict['labeled_data_dir'] = sorted(glob.glob(label_root_list + '/*.' + paraDict['labeled_data_suffix']),
                                                  key=os.path.getmtime)

            paraDict['filelist'] = filelist
            paraDict['label_root_list'] = label_root_list
        else:
            filepath = glob.glob(file_root)[0]
            # swap the data identifier for the label identifier in the name
            rawname = filename[-1].replace('.dat', '')
            filename.pop(-1)
            filename.append(rawname)
            labelname = path.replace_identifier(filename, configs.get('data_field_identifier', ''),
                                                configs.get('label_field_identifier', ''))
            label_root = os.path.join('', 'datasets', *labelname)
            paraDict['labeled_data_dir'] = glob.glob(label_root + '.' + paraDict['labeled_data_suffix'])[0]
    else:
        filepath = file_root

    paraDict['filepath'] = filepath

    return paraDict
|
90 |
+
|
91 |
+
|
92 |
+
class SpikeStream:
    """Reader for binary spike-camera ``.dat`` streams.

    Each frame stores one bit per pixel, packed 8 pixels per byte in
    row-major order; the ``get_*`` methods unpack the bit stream into a
    ``(T, H, W)`` matrix of 0/1 int8 values.
    """

    def __init__(self, **kwargs):
        # decoded result cache; filled by the get_* methods
        self.SpikeMatrix = None
        self.filename = kwargs.get('filepath')
        if os.path.splitext(self.filename)[-1][1:] != 'dat':
            self.filename = self.filename + '.dat'
        self.spike_w = kwargs.get('spike_w')
        self.spike_h = kwargs.get('spike_h')
        if 'print_dat_detail' not in kwargs:
            self.print_dat_detail = True
        else:
            self.print_dat_detail = kwargs.get('print_dat_detail')

    def get_spike_matrix(self, flipud=True, with_head=False):
        """Decode the whole ``.dat`` file.

        :param flipud: flip each frame vertically (sensor scan order)
        :param with_head: raw rows are 416 wide with 16 header columns that
            are stripped after decoding
        :return: (T, H, W) int8 matrix of 0/1 spikes
        """
        # 'with' guarantees the handle is closed even on decode errors
        with open(self.filename, 'rb') as file_reader:
            video_seq = file_reader.read()
        video_seq = np.frombuffer(video_seq, 'b')

        video_seq = np.array(video_seq).astype(np.byte)
        if self.print_dat_detail:
            print(video_seq)
        if with_head:
            decode_width = 416
        else:
            decode_width = self.spike_w
        img_size = self.spike_h * decode_width
        img_num = len(video_seq) // (img_size // 8)

        if self.print_dat_detail:
            print('loading total spikes from dat file -- spatial resolution: %d x %d, total timestamp: %d' %
                  (decode_width, self.spike_h, img_num))

        # pixel index -> (byte index, bit mask): 8 pixels packed per byte
        pix_id = np.arange(0, img_num * self.spike_h * decode_width)
        pix_id = np.reshape(pix_id, (img_num, self.spike_h, decode_width))
        comparator = np.left_shift(1, np.mod(pix_id, 8))
        byte_id = pix_id // 8

        data = video_seq[byte_id]
        result = np.bitwise_and(data, comparator)
        tmp_matrix = (result == comparator)

        if with_head:
            # drop the 16 header columns appended to each 400-pixel row
            delete_indx = np.arange(400, 416)
            tmp_matrix = np.delete(tmp_matrix, delete_indx, 2)

        if flipud:
            self.SpikeMatrix = tmp_matrix[:, ::-1, :]
        else:
            self.SpikeMatrix = tmp_matrix

        self.SpikeMatrix = self.SpikeMatrix.astype(np.byte)
        return self.SpikeMatrix

    # return spikes with specified length and begin index
    def get_block_spikes(self, begin_idx, block_len=1, flipud=True, with_head=False):
        """Decode ``block_len`` frames starting at frame ``begin_idx``.

        When the requested block runs past the end of the file, the missing
        frames are zero-padded (the original indexed past the truncated byte
        slice and raised ``IndexError`` instead of padding as the warning
        promised).

        :param begin_idx: index of the first frame to decode
        :param block_len: number of frames to decode
        :param flipud: flip each frame vertically (sensor scan order)
        :param with_head: see :meth:`get_spike_matrix`
        :return: (block_len, H, W) int8 matrix of 0/1 spikes
        """
        with open(self.filename, 'rb') as file_reader:
            video_seq = file_reader.read()
        video_seq = np.frombuffer(video_seq, 'b')

        video_seq = np.array(video_seq).astype(np.uint8)

        if with_head:
            decode_width = 416
        else:
            decode_width = self.spike_w
        img_size = self.spike_h * decode_width
        img_num = len(video_seq) // (img_size // 8)

        end_idx = begin_idx + block_len
        if end_idx > img_num:
            warnings.warn("block_len exceeding upper limit! Zeros will be padded in the end. ", ResourceWarning)
            end_idx = img_num

        if self.print_dat_detail:
            print(
                'loading total spikes from dat file -- spatial resolution: %d x %d, begin index: %d total timestamp: %d' %
                (decode_width, self.spike_h, begin_idx, block_len))

        pix_id = np.arange(0, block_len * self.spike_h * decode_width)
        pix_id = np.reshape(pix_id, (block_len, self.spike_h, decode_width))
        comparator = np.left_shift(1, np.mod(pix_id, 8))
        byte_id = pix_id // 8

        needed_bytes = block_len * img_size // 8
        id_start = begin_idx * img_size // 8
        data = video_seq[id_start:id_start + needed_bytes]
        if len(data) < needed_bytes:
            # pad the tail with zero bytes (no spikes), as warned above
            data = np.concatenate([data, np.zeros(needed_bytes - len(data), dtype=data.dtype)])
        data_frame = data[byte_id]
        result = np.bitwise_and(data_frame, comparator)
        tmp_matrix = (result == comparator)

        if with_head:
            # drop the 16 header columns appended to each 400-pixel row
            delete_indx = np.arange(400, 416)
            tmp_matrix = np.delete(tmp_matrix, delete_indx, 2)

        if flipud:
            self.SpikeMatrix = tmp_matrix[:, ::-1, :]
        else:
            self.SpikeMatrix = tmp_matrix

        self.SpikeMatrix = self.SpikeMatrix.astype(np.byte)
        return self.SpikeMatrix
|
snnTracker/spkProc/detection/attention_select.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import skimage.morphology as smor
|
2 |
+
from skimage.feature import peak_local_max
|
3 |
+
from skimage.morphology import erosion
|
4 |
+
from skimage.measure import label, regionprops
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from torchvision.transforms import Resize
|
8 |
+
|
9 |
+
|
10 |
+
# obtain 2D gaussian filter
def get_kernel(filter_size, sigma):
    """Build an odd-sized 2D Gaussian kernel, min-max normalized to [0, 1]."""
    assert (filter_size + 1) % 2 == 0, '2D filter size must be odd number!'
    half_width = (filter_size - 1) // 2
    # signed offsets of every cell from the kernel center, one axis
    offsets = np.arange(-half_width, half_width + 1, dtype=np.float32)
    # squared distance from the center for the full grid
    dist_sq = offsets[np.newaxis, :] ** 2 + offsets[:, np.newaxis] ** 2
    g = np.exp(-dist_sq / (2.0 * sigma * sigma)).astype(np.float32)
    # rescale so the smallest entry becomes 0 and the largest becomes 1
    return (g - g.min()) / (g.max() - g.min())
|
28 |
+
|
29 |
+
|
30 |
+
# detect moving connected regions
class SaccadeInput:
    """Dynamic-neural-field based attention module.

    Maintains a membrane-potential map ``U`` over the spike plane, advances
    it with incoming spike frames (``update_dnf``) and extracts attention
    boxes around supra-threshold activity regions
    (``get_attention_location``).
    """

    def __init__(self, spike_h, spike_w, box_size, device, attentionThr=None, extend_edge=None):
        """
        :param spike_h: spike-plane height in pixels
        :param spike_w: spike-plane width in pixels
        :param box_size: half-width of the attention box (kernel is 2*box_size+1)
        :param device: torch device for the field and the conv kernel
        :param attentionThr: threshold on U for attended locations (default 40)
        :param extend_edge: pixels added on each side of a detected region (default 7)
        """
        self.spike_h = spike_h
        self.spike_w = spike_w
        self.device = device

        # membrane potential of the field, one unit per pixel
        self.U = torch.zeros(self.spike_h, self.spike_w, dtype=torch.float32)
        self.tau_u = 0.5  # update step size applied to dU
        self.global_inih = 0.01  # global divisive inhibition strength
        self.box_width = box_size  # attention box width
        self.Jxx_size = self.box_width * 2 + 1
        # lateral-interaction kernel implemented as a fixed (untrained) conv
        self.Jxx = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(self.Jxx_size, self.Jxx_size),
                                   padding=(self.box_width, self.box_width), bias=False)

        tmp_filter = get_kernel(self.Jxx_size, round(self.box_width / 2) + 1)
        tmp_filter = tmp_filter.reshape((1, 1, self.Jxx_size, self.Jxx_size))
        self.Jxx.weight.data = torch.from_numpy(tmp_filter)
        # resamples every attention crop to the fixed Jxx_size x Jxx_size shape
        self.resizer = Resize((self.Jxx_size, self.Jxx_size))

        self.U = self.U.to(self.device)
        self.Jxx = self.Jxx.to(self.device)

        # threshold on U above which a location is considered attended
        if attentionThr is not None:
            self.attentionThr = attentionThr
        else:
            self.attentionThr = 40
        # margin added around each detected bounding box
        if extend_edge is not None:
            self.extend_edge = extend_edge
        else:
            self.extend_edge = 7
            # self.extend_edge = 1
        self.peak_width = int(self.extend_edge)

    def update_dnf(self, spike):
        """Advance the field one step with a binary spike frame (H x W)."""
        inputSpk = torch.reshape(spike, (1, 1, self.spike_h, self.spike_w)).float()

        # divisively-normalized firing rate of the current field state
        maxU = torch.relu(self.U)
        squareU = torch.square(maxU)
        r = squareU / (1 + self.global_inih * torch.sum(squareU))
        conv_fired = self.Jxx(inputSpk)
        conv_fired = torch.squeeze(conv_fired).to(self.device)
        du = conv_fired - self.U

        # recurrent lateral input from the field's own rate map
        r = torch.reshape(r, (1, 1, self.spike_h, self.spike_w))
        conv_r = self.Jxx(r)
        conv_r = torch.squeeze(conv_r).to(self.device)
        du = conv_r + du
        # detach: U is plain state, no gradient history should accumulate
        self.U += (du * self.tau_u).detach()

        del inputSpk, maxU, squareU, r, conv_r, conv_fired, du

    def get_attention_location(self, spikes):
        """Extract attention boxes and their resampled spike crops.

        :param spikes: current binary spike frame (H x W tensor)
        :return: ``(attentionBox, attentionInput)`` where ``attentionBox`` is
            (num_box, 4) int rows [beginX, beginY, endX, endY] and
            ``attentionInput`` is (Jxx_size + 4, Jxx_size, num_box); its last
            four rows binary-encode the four box coordinates
        """
        tmpU = torch.relu(self.U - self.attentionThr)
        tmpU = tmpU.cpu()
        tmpU = tmpU.detach().numpy()
        # NOTE(review): despite the variable name, this applies a
        # morphological *erosion* to suppress isolated peaks — confirm intent
        dilated_u = erosion(tmpU, smor.square(self.peak_width))
        peak_cord = peak_local_max(dilated_u, min_distance=self.box_width)
        num_max = len(peak_cord)
        # print('detect %d attention location' % num_max)
        # binarize, then find connected supra-threshold regions
        dilated_u[dilated_u > 1] = 1
        dilated_u[dilated_u < 1] = 0
        region_labels = label(dilated_u)
        regions = regionprops(region_labels)
        num_box = len(regions)

        attentionBox = torch.zeros((num_box, 4), dtype=torch.int)
        attentionInput = torch.zeros(self.Jxx_size + 4, self.Jxx_size, num_box)

        for region, iBox in zip(regions, range(num_box)):
            minr, minc, maxr, maxc = region.bbox
            # expand the region bbox by extend_edge, clipped to the frame
            # (Python and/or idiom used as a conditional expression)
            beginX = minr - self.extend_edge >= 0 and minr - self.extend_edge or 0
            beginY = minc - self.extend_edge >= 0 and minc - self.extend_edge or 0
            endX = maxr + self.extend_edge < self.spike_h and maxr + self.extend_edge or self.spike_h - 1
            endY = maxc + self.extend_edge < self.spike_w and maxc + self.extend_edge or self.spike_w - 1

            attentionBox[iBox, :] = torch.tensor([beginX, beginY, endX, endY])
            attentionI = torch.unsqueeze(spikes[beginX:endX + 1, beginY:endY + 1], dim=0)
            attentionI = self.resizer.forward(attentionI)
            # re-binarize after the interpolation performed by Resize
            fire_index = torch.where(attentionI > 0.9)
            attentionI2 = torch.zeros_like(attentionI)
            attentionI2[0, fire_index[1], fire_index[2]] = 1
            attentionInput[:-4, :, iBox] = torch.squeeze(attentionI2).detach().clone()
            # last four rows: box coordinates (+1) as zero-padded binary
            # strings, tiled twice to fill the row; assumes each coordinate
            # fits in box_width bits — TODO confirm for large frames
            tmp_spk = bin(beginX + 1)
            tmp_spk = tmp_spk[2:].zfill(self.box_width)
            attentionInput[-4, :-1, iBox] = torch.from_numpy(np.tile(np.array(list(tmp_spk), dtype=np.float32), (1, 2)))
            tmp_spk = bin(beginY + 1)
            tmp_spk = tmp_spk[2:].zfill(self.box_width)
            attentionInput[-3, :-1, iBox] = torch.from_numpy(np.tile(np.array(list(tmp_spk), dtype=np.float32), (1, 2)))
            tmp_spk = bin(endX + 1)
            tmp_spk = tmp_spk[2:].zfill(self.box_width)
            attentionInput[-2, :-1, iBox] = torch.from_numpy(np.tile(np.array(list(tmp_spk), dtype=np.float32), (1, 2)))
            tmp_spk = bin(endY + 1)
            tmp_spk = tmp_spk[2:].zfill(self.box_width)
            attentionInput[-1, :-1, iBox] = torch.from_numpy(np.tile(np.array(list(tmp_spk), dtype=np.float32), (1, 2)))

        attentionInput = attentionInput.to(self.device)
        del tmpU, dilated_u, peak_cord

        return attentionBox, attentionInput
|
snnTracker/spkProc/detection/motion_clustering.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Time : 2024/12/6 3:34
|
3 |
+
# @Author : Yajing Zheng
|
4 |
+
# @Email: [email protected]
|
5 |
+
# @File : motion_clustering.py
|
6 |
+
from sklearn.cluster import DBSCAN, OPTICS, SpectralClustering
|
7 |
+
from sklearn.metrics import pairwise_distances
|
8 |
+
from sklearn.preprocessing import StandardScaler
|
9 |
+
import numpy as np
|
10 |
+
import scipy.ndimage.measurements as mnts
|
11 |
+
import torch
|
12 |
+
|
13 |
+
|
14 |
+
class detect_object:
    """Cluster moving spike events into objects and derive bounding boxes."""

    def __init__(self, h, w):
        """
        :param h: spike-plane height in pixels
        :param w: spike-plane width in pixels
        """
        self.h = h
        self.w = w
        # clustering hyper-parameters shared by the alternative algorithms
        params = {'quantile': .3,
                  'eps': .4,
                  'damping': .9,
                  'preference': -200,
                  'n_neighbors': 10,
                  'min_samples': 50,
                  'xi': 0.05,
                  'min_cluster_size': 0.1,
                  'n_cluster': 2}
        self.dbscan = DBSCAN(eps=params['eps'], min_samples=params['min_samples'], metric='precomputed')
        self.optics = OPTICS(min_samples=params['min_samples'], xi=params['xi'],
                             min_cluster_size=params['min_cluster_size'], metric='precomputed')
        self.spectral = SpectralClustering(n_clusters=params['n_cluster'], eigen_solver='arpack',
                                           affinity='precomputed')

    def get_object(self, motion_vector, max_motion=None):
        """Cluster active pixels by a joint spatial + motion distance.

        :param motion_vector: (H, W, 2) per-pixel motion tensor
        :param max_motion: optional (H, W) activity map selecting pixels;
            when None, any pixel with a nonzero motion component fires
        :return: (labels, pixel_coords) where labels is the DBSCAN label per
            fired pixel (-1 = noise), or (None, None) when nothing fires
        """
        # motion_vector = StandardScaler.transform(motion_vector)
        if max_motion is None:
            mv_idx = torch.where(torch.logical_or(motion_vector[:, :, 0] != 0, motion_vector[:, :, 1] != 0))
            if len(mv_idx[0]) < 1:
                return None, None
            # np.int was removed from NumPy 1.24; the builtin int is the
            # documented drop-in replacement
            fire_idx = np.zeros((2, len(mv_idx[0])), dtype=int)
            fire_idx[0, :] = mv_idx[0].cpu().numpy()
            fire_idx[1, :] = mv_idx[1].cpu().numpy()

        else:
            max_motion = max_motion.cpu().numpy()
            if max_motion.max() < 1:
                return None, None
            fire_idx = np.array(np.nonzero(max_motion))

        motion_vector = motion_vector.cpu().numpy()

        fire_idx = fire_idx.T
        num_events = len(fire_idx)
        fire_idx_ori = fire_idx
        spatial_vector = StandardScaler().fit_transform(fire_idx)

        motion_array = np.zeros((num_events, 2))
        motion_array[:, 0] = motion_vector[fire_idx[:, 0], fire_idx[:, 1], 0]
        motion_array[:, 1] = motion_vector[fire_idx[:, 0], fire_idx[:, 1], 1]
        motion_array = StandardScaler().fit_transform(motion_array)

        # joint distance: average of motion-space and image-space distances
        motion_dis = pairwise_distances(motion_array, metric='euclidean')
        spatial_dis = pairwise_distances(spatial_vector, metric='euclidean')
        total_dis = 0.5 * (motion_dis + spatial_dis)

        self.dbscan.fit(total_dis)
        # self.optics.fit(total_dis)
        # self.spectral.fit(total_dis)
        labels = self.dbscan.labels_.astype(int)
        # labels_optics = self.optics.labels_.astype(int)
        # labels_spetral = self.spectral.labels_.astype(int)

        return labels, fire_idx_ori
        # return labels_optics, fire_idx_ori
        # return labels_spetral, fire_idx_ori

    def detection_object_with_motion(self, fireID, clusterId):
        """Convert clustered pixels into per-cluster bounding boxes.

        :param fireID: (N, 2) row/col coordinates of fired pixels
        :param clusterId: (N,) cluster label per pixel (-1 = noise, which
            maps to background label 0 and is skipped)
        :return: (num_clusters, 4) boxes as [begin_X, begin_Y, end_X, end_Y]
            (end coordinates are exclusive, slice-style)
        """
        # local import: scipy.ndimage.measurements is deprecated/removed;
        # the same function lives in scipy.ndimage
        from scipy import ndimage

        L = np.zeros((self.h, self.w), dtype=int)
        L[fireID[:, 0], fireID[:, 1]] = clusterId + 1

        bboxSlices = ndimage.find_objects(L)
        box_num = clusterId.max() + 1
        bbox = np.zeros((box_num, 4))
        for iBox in range(box_num):
            tmpBox = bboxSlices[iBox]
            begin_X = tmpBox[0].start
            end_X = tmpBox[0].stop
            begin_Y = tmpBox[1].start
            end_Y = tmpBox[1].stop

            bbox[iBox, :] = [begin_X, begin_Y, end_X, end_Y]

        # pprint(bbox)
        return bbox
|
101 |
+
|
snnTracker/spkProc/detection/stdp_clustering.py
ADDED
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# from config import *
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from collections import namedtuple
|
5 |
+
|
6 |
+
detect_box = namedtuple('detect_box', ['zId', 'box', 'velocity'])
|
7 |
+
tracks = namedtuple('tracks', ['id', 'color', 'bbox', 'predbox', 'visible', 'vel', 'age', 'unvisible_count'])
|
8 |
+
trajectories = namedtuple('trajectories', ['id', 'x', 'y', 't', 'color'])
|
9 |
+
tracks_bbox = namedtuple('tracks_bbox', ['id', 't', 'x', 'y', 'h', 'w'])
|
10 |
+
|
11 |
+
|
12 |
+
class stdp_cluster():
|
13 |
+
|
14 |
+
    def __init__(self, spike_h, spike_w, box_size, device):
        """STDP-trained winner-take-all clustering network over attention-box
        spike inputs, with per-neuron track/trajectory bookkeeping.

        :param spike_h: spike-plane height in pixels
        :param spike_w: spike-plane width in pixels
        :param box_size: half-width of the attention box; input length is
            derived from the resized box plus 4 coordinate-encoding rows
        :param device: torch device for the weight tensors
        """
        self.spike_h = spike_h
        self.spike_w = spike_w
        self.box_size = box_size
        self.K1 = 1  # input channels per output neuron
        self.K2 = 5  # number of output (cluster / track) neurons
        # self.InputSize = (2 * box_size + 1)**2
        # flattened input: (2b+1) x (2b+5) = box rows + 4 coordinate rows
        self.InputSize = (2 * box_size + 1) * (2 * box_size + 5)
        self.device = device
        # self.InputSize = box_size * (box_size + 4)
        # self.InputSize = box_size**2

        # self.synaptic_weight = torch.ones(self.K1, self.InputSize, dtype=torch.float64) / self.K1
        self.synaptic_weight = torch.rand(self.K2, self.K1, self.InputSize, dtype=torch.float64)
        # self.synaptic_weight = self.synaptic_weight / torch.sum(self.synaptic_weight)
        # self.synaptic_weight = torch.unsqueeze(self.synaptic_weight, dim=0)
        # self.synaptic_weight = self.synaptic_weight.repeat(self.K2, 1, 1)
        self.bias_weight = torch.ones(self.K2, 1, dtype=torch.float64) / self.K2

        # log-normalize so exp(weights) sum to one (see normalization_w)
        self.synaptic_weight = self.normalization_w(self.synaptic_weight)
        self.bias_weight = self.normalization_w(self.bias_weight)
        self.synaptic_weight = self.synaptic_weight.to(device)
        self.bias_weight = self.bias_weight.to(device)

        self.learning_rate = 0.001
        self.iter_num = 5
        self.stdp_coefficience = 1

        # weight clipping bounds used by the learning rule
        self.w_up = 1
        self.w_low = -8

        self.w_up_tensor = torch.ones_like(self.synaptic_weight, dtype=torch.float64) * self.w_up
        self.w_low_tensor = torch.ones_like(self.synaptic_weight, dtype=torch.float64) * self.w_low

        # one track / trajectory / bbox history per output neuron
        self.tracks = []
        self.trajectories = []
        self.tracks_bbox = []
        # self.seed_everything(5)

        # background oscillation rate 20 Hz -- NOTE(review): the constant is
        # 10 while the comment said 20 Hz; confirm the intended rate
        self.background_occ_fr = 10
        self.occ_fr = torch.Tensor([self.background_occ_fr / 20000.0])

        for i_neuron in range(self.K2):
            self.tracks.append(tracks(i_neuron, 255 * np.random.rand(1, 3), torch.zeros((1, 4), dtype=torch.float64),
                                      torch.zeros((1, 4), dtype=torch.float64), 0,
                                      torch.zeros((2,)), 0, 0))
            self.trajectories.append(trajectories(i_neuron, [], [], [], self.tracks[i_neuron].color))
            self.tracks_bbox.append(tracks_bbox(i_neuron, [], [], [], [], []))
|
63 |
+
|
64 |
+
|
65 |
+
def normalization_w(self, weight):
    """Renormalize a weight tensor in log-space (log-softmax style).

    A 2-D tensor (bias weights) is softmax-normalized over all of its
    entries; a 3-D tensor (synaptic weights, shape (K2, K1, InputSize))
    is normalized along its last (input) dimension.  Returning the log
    of the normalized values keeps the weights additive when membrane
    potentials are computed later.
    """
    if weight.dim() == 2:
        # Bias weights: one global softmax over every entry.
        expw = torch.exp(weight)
        total = torch.sum(expw)
        if total == 0:
            # Degenerate case (every exp underflowed): leave values as-is.
            result = weight.clone()
        else:
            result = torch.log(expw / total)
        del expw
    else:
        # Synaptic weights: softmax along the input dimension (dim 2).
        expw = torch.exp(weight)
        denom = torch.sum(expw, dim=2)
        denom = torch.unsqueeze(denom, dim=2).tile(1, 1, self.InputSize)
        result = expw.detach().clone()
        # Divide only where the denominator is non-zero to avoid NaNs.
        nonzero = torch.where(denom != 0)
        result[nonzero] = expw[nonzero] / denom[nonzero]
        result = torch.log(result)
        del expw, denom, nonzero

    torch.cuda.empty_cache()
    return result
|
91 |
+
|
92 |
+
# winner-take-all
# @staticmethod
def wta(self, attention_spikes, synaptic_weight, bias_weight):
    """Stochastic soft winner-take-all over the K2 tracking neurons.

    Computes each neuron's membrane potential from the (shifted) synaptic
    weights, applies divisive inhibition (softmax in log-space), then
    samples a binary spike vector from the resulting firing probabilities.
    Returns a 0/1 tensor of shape (K2,) on the CPU.
    """

    attention_spikes = attention_spikes.double()
    # Shift weights by |w_low| so the matmul works on non-negative values.
    synaptic_weight = torch.squeeze(synaptic_weight + abs(self.w_low))
    # synaptic_weight = torch.squeeze(synaptic_weight)
    intPre = torch.matmul(synaptic_weight, attention_spikes)
    intPre = torch.squeeze(intPre)
    intPre[intPre > 700] = 700  # clamp so exp() below cannot overflow float64
    psp = intPre + torch.squeeze(bias_weight)

    rate_norm = (torch.Tensor([1]) * 1.0).to(self.device)
    # Divisive inhibition in log-space: subtract log-sum-exp of all potentials.
    sum_exp = torch.sum(torch.exp(psp))
    fire_inhb = torch.log(sum_exp) - torch.log(rate_norm)

    tmp_psp = torch.exp(psp - fire_inhb)  # softmax firing probabilities
    tmp_psp[torch.isinf(tmp_psp)] = 0
    # NOTE(review): a single uniform draw is compared against every neuron's
    # probability, so more than one neuron can fire in the same round.
    fire_index = torch.where(torch.rand(1).to(self.device) < tmp_psp)
    Z_spike = torch.zeros(tmp_psp.shape)
    Z_spike[fire_index] = 1

    del intPre, psp, rate_norm, sum_exp, fire_inhb, tmp_psp, fire_index
    torch.cuda.empty_cache()
    return Z_spike
|
117 |
+
|
118 |
+
@staticmethod
def intersect(box_a, box_b):
    """Pairwise intersection areas between two sets of boxes.

    Both inputs hold corner-format boxes [x1, y1, x2, y2].  The tensors
    are expanded to [A, B, 2] views (no new allocation) so the min/max
    of the corners can be taken for every (a, b) pair at once.

    Args:
        box_a: (tensor) bounding boxes, Shape: [A,4].
        box_b: (tensor) bounding boxes, Shape: [B,4].
    Return:
        (tensor) intersection area, Shape: [A,B].
    """
    n_a = box_a.size(0)
    n_b = box_b.size(0)
    # Lower-right corner of each intersection rectangle.
    lower_right = torch.min(box_a[:, 2:].unsqueeze(1).expand(n_a, n_b, 2),
                            box_b[:, 2:].unsqueeze(0).expand(n_a, n_b, 2))
    # Upper-left corner of each intersection rectangle.
    upper_left = torch.max(box_a[:, :2].unsqueeze(1).expand(n_a, n_b, 2),
                           box_b[:, :2].unsqueeze(0).expand(n_a, n_b, 2))
    # Negative extents mean no overlap; clamp them to zero.
    sides = torch.clamp(lower_right - upper_left, min=0)
    return sides[:, :, 0] * sides[:, :, 1]
|
138 |
+
|
139 |
+
def jaccard(self, box_a, box_b):
    """Jaccard (IoU) overlap between two sets of corner-format boxes.

    A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)

    Args:
        box_a: (tensor) Ground truth bounding boxes, Shape: [A,4]
        box_b: (tensor) Prior boxes from priorbox layers, Shape: [B,4]
    Return:
        jaccard overlap: (tensor) Shape: [A, B]
    """
    overlap = self.intersect(box_a, box_b)
    # Individual box areas, broadcast to the pairwise [A, B] grid.
    widths_a = box_a[:, 2] - box_a[:, 0]
    heights_a = box_a[:, 3] - box_a[:, 1]
    widths_b = box_b[:, 2] - box_b[:, 0]
    heights_b = box_b[:, 3] - box_b[:, 1]
    area_a = (widths_a * heights_a).unsqueeze(1).expand_as(overlap)  # [A,B]
    area_b = (widths_b * heights_b).unsqueeze(0).expand_as(overlap)  # [A,B]
    union = area_a + area_b - overlap
    return overlap / union  # [A,B]
|
158 |
+
|
159 |
+
def update_weight(self, attention_input):
    """Run one EM-like STDP learning pass over a batch of attention windows.

    Args:
        attention_input: tensor whose last dimension indexes attention
            patterns; each pattern is flattened into a spike column vector.

    Returns:
        (predict_fire, synaptic_weight, bias_weight): a (n_attention, K2)
        0/1 indicator of which neuron claimed each pattern, plus the updated
        local copies of the weights.  ``self.synaptic_weight`` and
        ``self.bias_weight`` are NOT mutated here; the caller decides
        whether to commit the returned copies.
    """

    n_attention = attention_input.shape[2]
    predict_fire = torch.zeros(n_attention, self.K2)
    # Work on detached copies of the weights.
    synaptic_weight = self.synaptic_weight.detach().clone()
    bias_weight = self.bias_weight.detach().clone()
    # Per-neuron update counter: yields a 1/n decaying learning rate below.
    lr_weight = torch.zeros(self.K2).to(self.device)
    # Each neuron may claim (and learn from) at most one pattern per pass.
    has_fired = np.zeros((self.K2, 1))

    for iPattern in range(n_attention):
        detected = -1  # neuron index that has claimed this pattern, -1 = none yet
        input_spike = torch.reshape(attention_input[:, :, iPattern], (-1, 1))
        # background_noise = (torch.rand(input_spike.shape) < self.occ_fr).to(device)
        # input_spike = (torch.logical_or((input_spike).type(torch.bool), background_noise)).type(torch.float32)
        # input_spike = torch.reshape(attention_input, (-1, n_attention))
        confusion_flag = 0  # NOTE(review): assigned but never read afterwards
        for i in range(self.iter_num):
            # One stochastic WTA round; then nudge the bias toward the
            # observed firing (exp(-bias) factor keeps the update in log-space).
            z_spike = self.wta(input_spike, synaptic_weight, bias_weight).to(self.device)
            dw_bias = self.learning_rate * (z_spike * torch.squeeze(torch.exp(-bias_weight)) - 1)
            # tmp_sum = torch.sum(dw_bias, dim=1)
            bias_weight += torch.unsqueeze(dw_bias, dim=1).detach()

            for iZ in range(self.K2):
                # Only a firing neuron that has not yet claimed a pattern is
                # updated; once `detected` is set, no other neuron can claim
                # this pattern (and the detected one is blocked by has_fired).
                if z_spike[iZ] != 0 and has_fired[iZ] == 0 and (detected == -1 or iZ == detected):
                    has_fired[iZ] = 1
                    detected = iZ
                    # fire_idx = torch.where(z_spike[iZ, :]!=0)
                    # STDP: potentiate synapses whose input spiked, depress the rest.
                    tmpE = torch.exp(-synaptic_weight[iZ, :, :])
                    dw = self.stdp_coefficience * tmpE * torch.transpose(input_spike, 0, 1) - 1
                    lr_weight[iZ] += 1

                    synaptic_weight[iZ, :, :] += ((1.0 / lr_weight[iZ]) * dw.to(self.device))
                    # synaptic_weight[iZ, :, :] += self.learning_rate * dw.to(device)

                    # Clip weights into [w_low, w_up] before renormalizing.
                    synaptic_weight = torch.where(synaptic_weight < self.w_up, synaptic_weight,
                                                  self.w_up_tensor).detach()
                    synaptic_weight = torch.where(synaptic_weight < self.w_low, self.w_low_tensor,
                                                  synaptic_weight).detach()

                    synaptic_weight = self.normalization_w(synaptic_weight)
                    bias_weight = self.normalization_w(bias_weight)
                    predict_fire[iPattern, iZ] = torch.Tensor([1])

    # predict_fire[iPattern, :] = z_spike.detach()
    # predict_fire = z_spike.detach()
    # print(synaptic_weight.max())
    # print(synaptic_weight.min())

    del n_attention, lr_weight
    torch.cuda.empty_cache()
    return predict_fire, synaptic_weight, bias_weight
|
210 |
+
|
211 |
+
def seed_everything(self, seed=11):
    """Seed every RNG (NumPy, torch CPU and CUDA) for reproducible runs.

    Also forces cuDNN into deterministic mode and disables its
    auto-tuning benchmark so repeated runs produce identical results.
    """
    for seeder in (np.random.seed,
                   torch.manual_seed,
                   torch.cuda.manual_seed,
                   torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    return
|
220 |
+
|
221 |
+
@staticmethod
def detect_object(predict_fire, attention_box, motion_id, motion_vector, **kwargs):
    """Turn WTA firing results into detected objects with velocities.

    For each attention window whose neuron fired, estimates the object's
    velocity as the mean motion vector over the moving pixels inside the
    window, and emits a ``detect_box`` (neuron id, box, velocity).

    Args:
        predict_fire: (nAttention, K2) 0/1 firing indicator from the WTA.
        attention_box: (nAttention, 4) integer boxes [x1, y1, x2, y2].
        motion_id: (spike_h, spike_w) map; non-zero marks moving pixels.
        motion_vector: (spike_h, spike_w, 2) per-pixel motion components.
        **kwargs: 'spike_h', 'spike_w', 'device'.

    Returns:
        list of ``detect_box`` namedtuples (defined elsewhere in this module),
        one per attention window with both a firing neuron and motion inside.
    """

    spike_h = kwargs.get('spike_h')
    spike_w = kwargs.get('spike_w')
    device = kwargs.get('device')
    nAttention = attention_box.shape[0]
    boxId = torch.zeros(nAttention, 1)
    predBox = torch.zeros((nAttention, 4), dtype=torch.int)
    velocities = torch.zeros(nAttention, 2).to(device)
    predict_box = []

    for iPattern in range(nAttention):
        z_spike = predict_fire[iPattern, :]
        if torch.any(z_spike != 0):
            # More than one firing neuron for a single window is unexpected.
            if len(torch.where(z_spike != 0)[0]) > 1:
                print('check')

            # Use the first firing neuron; ids are 1-based (0 means "none").
            tmp_fired = torch.where(z_spike != 0)[0]
            boxId[iPattern] = tmp_fired[0] + 1

            x = attention_box[iPattern, 0]
            y = attention_box[iPattern, 1]
            end_x = attention_box[iPattern, 2]
            end_y = attention_box[iPattern, 3]

            # Mask the motion-id map down to this attention window.
            tmp_motion = torch.zeros(spike_h, spike_w)
            tmp_motion[x:end_x + 1, y:end_y + 1] = motion_id[x:end_x + 1, y:end_y + 1].clone()

            motion_index2d = torch.where(tmp_motion != 0)
            # No moving pixels inside the window: skip this detection.
            if len(motion_index2d[0]) == 0:
                continue

            # Velocity = mean motion vector over the moving pixels.
            motion_num = len(motion_index2d[0])
            block_veloctiy = torch.zeros(motion_num, 2).to(device)
            block_veloctiy[:, 0] = motion_vector[motion_index2d[0], motion_index2d[1], 0].clone()
            block_veloctiy[:, 1] = motion_vector[motion_index2d[0], motion_index2d[1], 1].clone()
            tmp_veloctiy = torch.mean(block_veloctiy, dim=0)
            velocities[iPattern, :] = tmp_veloctiy.data
            predBox[iPattern, :] = attention_box[iPattern, :]

            predict_box.append(detect_box(boxId[iPattern],
                                          torch.unsqueeze(predBox[iPattern], dim=0),
                                          velocities[iPattern]))
        # else:
        #     print('no tracking neuron fire..')

    del boxId, predBox, velocities
    torch.cuda.empty_cache()

    return predict_box
|
272 |
+
|
273 |
+
def update_tracks(self, detect_objects, sw, bw, timestep):
    """Assign detections to tracks and extrapolate unassigned tracks.

    First pass: each detected object updates the track with the same neuron
    id (box, velocity, visibility, trajectory and bbox history).  Second
    pass: tracks with no detection this step either die (after >5 invisible
    steps) or are extrapolated along their last velocity.

    Args:
        detect_objects: list of ``detect_box`` (zId, box, velocity).
        sw, bw: synaptic / bias weight tensors that may be corrected in place.
        timestep: current frame index, appended to trajectory timelines.

    Returns:
        (sw, bw) with weights of unassigned tracks restored from
        ``self.synaptic_weight`` / ``self.bias_weight``.
    """

    objects_num = len(detect_objects)
    id_check = torch.zeros(self.K2, 1)  # which track ids were seen this step
    AssignTrk = []

    for iObject in range(objects_num):
        tmp_object = detect_objects[iObject]
        id = int(tmp_object.zId.detach().item())
        box = tmp_object.box
        velocity = tmp_object.velocity

        # Duplicate detections for the same id invalidate the assignment.
        if id_check[id - 1] != 0:
            if id in AssignTrk:
                # print('id %d repeat' % (id-1))
                AssignTrk.remove(id)
            continue
        else:
            id_check[id - 1] = 1

        pred_box = self.tracks[id - 1].predbox
        boxes_iou = self.jaccard(box, pred_box)
        unvisible_count = self.tracks[id - 1].unvisible_count  # NOTE(review): unused here
        # NOTE(review): `~` on a Python bool is bitwise NOT (~False == -1,
        # which is truthy), not logical negation; when the `and` chain
        # short-circuits to a plain bool this condition is always true.
        # `not (...)` was probably intended — confirm before changing.
        if ~(self.tracks[id - 1].predbox[0, 3] != 0 and self.tracks[id - 1].age > 15
             and boxes_iou < 0.6):
            self.tracks[id - 1] = self.tracks[id - 1]._replace(bbox=box)
            beginX = box[0, 0]
            beginY = box[0, 1]
            endX = box[0, 2]
            endY = box[0, 3]

            # Shift the box by the velocity, clamped to the frame.
            # NOTE(review): the `cond and x or y` idiom returns y whenever
            # x is falsy, so a legitimate coordinate of exactly 0 falls
            # through to the fallback value.
            beginX = beginX + velocity[0] >= 0 and (beginX + velocity[0]) or 0
            beginY = beginY + velocity[1] >= 0 and (beginY + velocity[1]) or 0
            # endX = beginX + self.box_size * 2 < self.spike_h and (beginX + self.box_size * 2) or (self.spike_h - 1)
            # endY = beginY + self.box_size * 2 < self.spike_w and (beginY + self.box_size * 2) or (self.spike_w - 1)
            endX = endX + velocity[0] < self.spike_h and (endX + velocity[0]) or (self.spike_h - 1)
            endY = endY + velocity[1] < self.spike_w and (endY + velocity[1]) or (self.spike_w - 1)

            tmp_box = torch.tensor([beginX, beginY, endX, endY])
            tmp_box = torch.unsqueeze(tmp_box, dim=0)
            self.tracks[id - 1] = self.tracks[id - 1]._replace(predbox=tmp_box)

            self.tracks[id - 1] = self.tracks[id - 1]._replace(visible=1)
            self.tracks[id - 1] = self.tracks[id - 1]._replace(vel=velocity)
            self.tracks[id - 1] = self.tracks[id - 1]._replace(unvisible_count=0)
            self.tracks[id - 1] = self.tracks[id - 1]._replace(age=self.tracks[id - 1].age + 1)

            # update the trajectories
            self.trajectories[id - 1].x.append((box[0, 0] + self.box_size).item())
            self.trajectories[id - 1].y.append((box[0, 1] + self.box_size).item())
            self.trajectories[id - 1].t.append(timestep)

            # Check if beginX, beginY, endX, endY are int; otherwise, use .item()
            self.tracks_bbox[id - 1].x.append(beginY if isinstance(beginY, int) else beginY.item())
            self.tracks_bbox[id - 1].y.append(beginX if isinstance(beginX, int) else beginX.item())
            self.tracks_bbox[id - 1].h.append(
                (endX - beginX) if isinstance(endX, int) and isinstance(beginX, int) else (endX - beginX).item())
            self.tracks_bbox[id - 1].w.append(
                (endY - beginY) if isinstance(endY, int) and isinstance(beginY, int) else (endY - beginY).item())

            self.tracks_bbox[id - 1].t.append(timestep)
        AssignTrk.append(id)
        # print('tracks %d velocity dx: %f dy: %f' % (id, velocity[0], velocity[1]))

    # Second pass: handle tracks that received no detection this step.
    all_id = list(range(1, self.K2 + 1, 1))
    noAssign = [x for x in all_id if x not in AssignTrk]

    noAssign_num = self.K2 - len(AssignTrk)

    for iObject in range(noAssign_num):
        id = noAssign[iObject]
        unvisible_count = self.tracks[id - 1].unvisible_count
        if unvisible_count > 5:
            # Invisible too long: kill the track.
            self.tracks[id - 1] = self.tracks[id - 1]._replace(age=0)
            self.tracks[id - 1] = self.tracks[id - 1]._replace(visible=0)
            # sw[id-1, :, :] = 1 / self.K1
            # bw[id-1] = 1 / self.K2
        else:
            # Extrapolate the last predicted box along the stored velocity.
            if self.tracks[id - 1].predbox[0, 2] != 0:
                self.tracks[id - 1] = self.tracks[id - 1]._replace(bbox=self.tracks[id - 1].predbox)
                beginX = self.tracks[id - 1].predbox[0, 0].item()
                beginY = self.tracks[id - 1].predbox[0, 1].item()
                endX = self.tracks[id - 1].predbox[0, 2].item()
                endY = self.tracks[id - 1].predbox[0, 3].item()

                beginX = beginX + self.tracks[id - 1].vel[0] >= 0 and (beginX + self.tracks[id - 1].vel[0]) or 0
                beginY = beginY + self.tracks[id - 1].vel[1] >= 0 and (beginY + self.tracks[id - 1].vel[1]) or 0
                # endX = beginX + self.box_size * 2 < self.spike_h and (beginX + self.box_size * 2) or (self.spike_h - 1)
                # endY = beginY + self.box_size * 2 < self.spike_w and (beginY + self.box_size * 2) or (self.spike_w - 1)
                endX = endX + self.tracks[id - 1].vel[0] < self.spike_h and (endX + self.tracks[id - 1].vel[0]) or (
                        self.spike_h - 1)
                endY = endY + self.tracks[id - 1].vel[1] < self.spike_w and (endY + self.tracks[id - 1].vel[1]) or (
                        self.spike_w - 1)

                pred_box = torch.tensor([beginX, beginY, endX, endY])
                pred_box = torch.unsqueeze(pred_box, dim=0)
                self.tracks[id - 1] = self.tracks[id - 1]._replace(predbox=pred_box)

                self.trajectories[id - 1].x.append(
                    (beginX + self.box_size) if isinstance(beginX, int) else (beginX + self.box_size).item())
                self.trajectories[id - 1].y.append(
                    (beginY + self.box_size) if isinstance(beginY, int) else (beginY + self.box_size).item())
                self.trajectories[id - 1].t.append(timestep)

                # Check if beginX, beginY, endX, endY are int; otherwise, use .item()
                self.tracks_bbox[id - 1].x.append(beginY if isinstance(beginY, int) else beginY.item())
                self.tracks_bbox[id - 1].y.append(beginX if isinstance(beginX, int) else beginX.item())
                self.tracks_bbox[id - 1].h.append(
                    (endX - beginX) if isinstance(endX, int) and isinstance(beginX, int) else (
                            endX - beginX).item())
                self.tracks_bbox[id - 1].w.append(
                    (endY - beginY) if isinstance(endY, int) and isinstance(beginY, int) else (
                            endY - beginY).item())
                self.tracks_bbox[id - 1].t.append(timestep)
                # print('predicting location of object %d the %d time' % (id, unvisible_count))

            # print('tracks %d predictive velocity dx: %f, dy: %f' % (
            #     id, self.tracks[id - 1].vel[0], self.tracks[id - 1].vel[1]))
            # Restore this neuron's weights (tensor `~` here is element-wise
            # logical NOT, which is well-defined, unlike the bool case above).
            if ~(torch.all(self.synaptic_weight == 1)):
                sw[id - 1, :, :] = self.synaptic_weight[id - 1, :, :].detach().clone()
                bw[id - 1] = self.bias_weight[id - 1].detach().clone()
                # print('correct the weight')
        self.tracks[id - 1] = self.tracks[id - 1]._replace(unvisible_count=self.tracks[id - 1].unvisible_count + 1)

    return sw, bw
|
snnTracker/spkProc/filters/stp_filters_torch.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Time : 2021/11/19 16:25
|
3 |
+
# @Author : Yajing Zheng
|
4 |
+
# @File : stp_filters_torch.py
|
5 |
+
import copy
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
class STPFilter:
    """Short-term-plasticity (STP) based spike filter with a LIF detect layer.

    Maintains per-pixel STP resources R and utilization u (Tsodyks-Markram
    style dynamics) over a (spike_h, spike_w) frame, derives a filtered
    spike map from changes in R, and runs a locally connected LIF layer
    (3x3 all-ones convolution by default) on top for spatial denoising.
    """

    def __init__(self, spike_h, spike_w, device, diff_time=1, **STPargs):
        # Frame geometry and compute device.
        self.spike_h = spike_h
        self.spike_w = spike_w
        self.device = device

        # specify stp parameters
        if STPargs.get('u0', None) is None:
            # Defaults: baseline utilization u0, depression/facilitation
            # time constants D and F, facilitation increment f, and the
            # divisor converting raw timestamps to model time.
            self.u0 = 0.1
            self.D = 0.02
            self.F = 1.7
            self.f = 0.11
            self.time_unit = 2000

        else:
            self.u0 = STPargs.get('u0')
            self.D = STPargs.get('D')
            self.F = STPargs.get('F')
            self.f = STPargs.get('f')
            self.time_unit = STPargs.get('time_unit')

        self.r0 = 1  # initial (full) synaptic resource level

        self.diff_time = diff_time  # duration of window for record past dynamics for calculating the differnece
        self.R = torch.ones(self.spike_h, self.spike_w) * self.r0
        self.u = torch.ones(self.spike_h, self.spike_w) * self.u0
        # Ring buffer of past R maps used for the temporal difference.
        self.r_old = torch.ones(self.diff_time, self.spike_h, self.spike_w) * self.r0

        self.R = self.R.to(self.device)
        self.u = self.u.to(self.device)
        self.r_old = self.r_old.to(self.device)

        # LIF detect layer parameters
        self.detectVoltage = torch.zeros(self.spike_h, self.spike_w).to(self.device)
        if STPargs.get('lifSize', None) is None:
            lifSize = 3
            paddingSize = 1
        else:
            lifSize = STPargs.get('lifSize')
            paddingSize = int((lifSize - 1) / 2)

        # All-ones local-connection kernel (weight 3.0), 'same' padding.
        self.lifConv = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(lifSize, lifSize),
                                       padding=(paddingSize, paddingSize),
                                       bias=False)
        self.lifConv.weight.data = torch.ones(1, 1, lifSize, lifSize) * 3.0

        self.lifConv = self.lifConv.to(self.device)
        if STPargs.get('filterThr', None) is None:
            self.filterThr = 0.1  # filter threshold
            self.voltageMin = -8  # floor for the LIF membrane voltage
            self.lifThr = 2       # LIF firing threshold
        else:
            self.filterThr = STPargs.get('filterThr')
            self.voltageMin = STPargs.get('voltageMin')
            self.lifThr = STPargs.get('lifThr')

        # Output spike maps and bookkeeping buffers.
        self.filter_spk = torch.zeros(self.spike_h, self.spike_w).to(self.device)
        self.lif_spk = torch.zeros(self.spike_h, self.spike_w).to(self.device)
        self.spikePrevMnt = torch.zeros([self.spike_h, self.spike_w], device=self.device)
        self.stp_gradient = 0
        self.adjusted_threshold = torch.zeros(self.spike_h, self.spike_w).to(self.device)

    def update_dynamics(self, curT, spikes):
        """Advance STP dynamics one step and emit the filtered spike map.

        Args:
            curT: current (integer) timestamp.
            spikes: (spike_h, spike_w) tensor; non-zero marks an input spike.

        Side effects: updates self.R, self.u, self.r_old, self.spikePrevMnt,
        self.stp_gradient, self.adjusted_threshold and self.filter_spk.
        """

        spikeCurMnt = self.spikePrevMnt.detach().clone()
        spike_bool = spikes.bool()
        spikeCurMnt[spike_bool] = curT + 1
        # Inter-spike interval per pixel, rescaled into model time.
        dttimes = spikeCurMnt - self.spikePrevMnt
        dttimes = dttimes / self.time_unit
        # Tsodyks-Markram recovery (R) and facilitation (u) updates,
        # applied only at pixels that spiked this step.
        exp_D = torch.exp((-dttimes[spike_bool] / self.D))
        self.R[spike_bool] = 1 - (1 - self.R[spike_bool] * (1 - self.u[spike_bool])) * exp_D
        exp_F = torch.exp((-dttimes[spike_bool] / self.F))
        self.u[spike_bool] = self.u0 + (
                self.u[spike_bool] + self.f * (1 - self.u[spike_bool]) - self.u0) * exp_F

        tmp_diff = torch.abs(self.R - self.r_old[0])
        # Dynamically adjust the filter threshold according to the STP
        # gradient (exponential moving average of the relative R change).
        self.stp_gradient = (0.5 * self.stp_gradient + 0.5 * torch.div(tmp_diff, self.R))
        gradient_sqrt = torch.from_numpy(np.sqrt(self.stp_gradient.cpu().numpy()) + 1).to(self.device)
        self.adjusted_threshold = torch.div(self.filterThr, gradient_sqrt)

        # A pixel passes the filter when it spiked AND its R changed enough.
        self.filter_spk[:] = 0
        # self.filter_spk[spike_bool & (tmp_diff >= self.filterThr)] = 1
        self.filter_spk[spike_bool & (tmp_diff >= self.adjusted_threshold)] = 1

        # Slide the R history window (fill phase, then shift phase).
        if curT < self.diff_time:
            self.r_old[curT] = self.R.detach().clone()
        else:
            self.r_old[0:-1] = self.r_old[1:].detach().clone()
            self.r_old[-1] = self.R.detach().clone()
        self.spikePrevMnt = spikeCurMnt.detach().clone()
        del spikeCurMnt, dttimes, exp_D, exp_F, tmp_diff

    def update_dynamic_offline(self, spikes, intervals):
        """Replay STP dynamics over a whole recording at once.

        Args:
            spikes: (T, spike_h, spike_w) spike frames.
            intervals: (T, spike_h, spike_w) per-pixel inter-spike intervals.
                NOTE(review): `torch.from_numpy(tmp_isi)` implies `intervals`
                is a numpy array while `spikes[t, :, :] == 1` is used for
                boolean indexing into torch tensors — confirm the expected
                input types against the caller.

        Returns:
            (R, u): per-timestep snapshots of the resource and utilization
            maps; index 0 stays at the initial values r0 / u0.
        """

        isi_num = intervals.shape[0]
        R = torch.ones(isi_num, self.spike_h, self.spike_w) * self.r0
        u = torch.ones(isi_num, self.spike_h, self.spike_w) * self.u0
        prev_isi = intervals[0, :, :]
        # NOTE(review): prev_isi is never advanced inside the loop, so every
        # step compares against the t=0 intervals — verify this is intended.

        for t in range(1, isi_num):
            tmp_isi = intervals[t, :, :]
            # Update pixels whose ISI changed while spiking, or with ISI == 1.
            update_idx = (tmp_isi != prev_isi) & (spikes[t, :, :] == 1) | (tmp_isi == 1)
            tmp_isi = torch.from_numpy(tmp_isi).to(self.device).float()

            exp_D = torch.exp((-tmp_isi[update_idx] / self.D))
            self.R[update_idx] = 1 - (1 - self.R[update_idx] * (1 - self.u[update_idx])) * exp_D
            exp_F = torch.exp((-tmp_isi[update_idx] / self.F))
            self.u[update_idx] = self.u0 + (
                    self.u[update_idx] + self.f * (1 - self.u[update_idx]) - self.u0) * exp_F

            tmp_r = self.R.detach().clone()
            tmp_u = self.u.detach().clone()
            R[t, :, :] = copy.deepcopy(tmp_r)
            u[t, :, :] = copy.deepcopy(tmp_u)

        return R, u

    def local_connect(self, spikes):
        """One LIF step of the locally connected detect layer.

        Silent pixels leak (-1); every pixel integrates the 3x3 spike
        neighborhood; pixels above lifThr emit into self.lif_spk and are
        soft-reset (voltage scaled by 0.8).
        """
        inputSpk = torch.reshape(spikes, (1, 1, self.spike_h, self.spike_w)).float()
        # tmp_fired = spikes != 0
        self.detectVoltage[spikes == False] -= 1
        tmpRes = self.lifConv(inputSpk)
        tmpRes = torch.squeeze(tmpRes).to(self.device)
        self.detectVoltage += tmpRes.data
        # Clamp the membrane voltage from below.
        self.detectVoltage[self.detectVoltage < self.voltageMin] = self.voltageMin

        self.lif_spk[:] = 0
        self.lif_spk[self.detectVoltage >= self.lifThr] = 1
        # Soft reset of fired neurons instead of zeroing.
        self.detectVoltage[self.detectVoltage >= self.lifThr] *= 0.8
        # self.detectVoltage[(self.detectVoltage < self.lifThr) & (self.detectVoltage > 0)] = 0

        del inputSpk, tmpRes

    def local_connect_offline(self, spikes):
        """Run the LIF detect layer over a stack of spike frames.

        Args:
            spikes: (T, spike_h, spike_w) numpy array of 0/1 frames
                (each frame is converted with torch.from_numpy).

        Returns:
            (tmp_voltage, lif_spk): per-timestep lists of the membrane
            voltage and output spike maps as numpy arrays.
        """
        timestamps = spikes.shape[0]
        tmp_voltage = []
        lif_spk = []

        for iSpk in range(timestamps):
            tmp_spikes = spikes[iSpk]
            tmp_spk = torch.from_numpy(spikes[iSpk]).to(self.device)
            inputSpk = torch.reshape(tmp_spk, (1, 1, self.spike_h, self.spike_w)).float()
            # tmp_fired = spikes != 0
            self.detectVoltage[tmp_spikes == 0] -= 1  # leak on silent pixels
            tmpRes = self.lifConv(inputSpk)
            tmpRes = torch.squeeze(tmpRes).to(self.device)
            self.detectVoltage += tmpRes.data
            self.detectVoltage[self.detectVoltage < self.voltageMin] = self.voltageMin

            self.lif_spk[:] = 0
            self.lif_spk[self.detectVoltage >= self.lifThr] = 1
            # self.detectVoltage[(self.detectVoltage < self.lifThr) & (self.detectVoltage > 0)] = 0
            self.detectVoltage[self.detectVoltage >= self.lifThr] *= 0.8
            voltage = self.detectVoltage.cpu().detach().numpy()
            tmp_voltage.append(copy.deepcopy(voltage))
            lif_spk.append(self.lif_spk.cpu().detach().numpy())

        del inputSpk, tmpRes
        return tmp_voltage, lif_spk
|
snnTracker/spkProc/motion/motion_detection.py
ADDED
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils import get_kernel, get_transform_matrix_new, visualize_images
|
2 |
+
import torchgeometry as tgm
|
3 |
+
import numpy as np
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import torch
|
6 |
+
|
7 |
+
class motion_estimation:
|
8 |
+
|
9 |
+
def __init__(self, dvs_h, dvs_w, device, logger):
    """Motion-pattern estimator over an (dvs_h, dvs_w) spike frame.

    Builds a bank of num_ori * num_speed motion hypotheses (8 directions
    x 6 speeds) plus the convolution kernels used for local pooling,
    Gaussian smoothing and lateral (WTA) inhibition, and a per-direction
    color palette for visualization.

    Args:
        dvs_h, dvs_w: sensor height / width in pixels.
        device: torch device all buffers are moved to.
        logger: external logger object (stored; not used in this method).
    """

    self.dvs_h = dvs_h
    self.dvs_w = dvs_w
    self.device = device
    self.logger = logger

    # motion parameters
    self.orientation = range(0, 180 - 1, int(180 / 4))  # NOTE(review): not referenced below — confirm it is still needed
    # eight moving direction
    '''
    self.ori = torch.Tensor(np.array([[-1, -1],
                                      [0, -1],
                                      [1, -1],
                                      [-1, 0],
                                      [1, 0],
                                      [-1, 1],
                                      [0, 1],
                                      [1, 1]], dtype=np.uint8)).to(self.device)
    '''
    # self.ori = np.array([[-1, -1],
    #                      [0, -1],
    #                      [1, -1],
    #                      [1, 0],
    #                      [-1, 0],
    #                      [-1, 1],
    #                      [0, 1],
    #                      [1, 1]], dtype=np.int32)

    # Unit direction vectors, ordered counter-clockwise starting at +x.
    self.ori = np.array([[1, 0],
                         [1, 1],
                         [0, 1],
                         [-1, 1],
                         [-1, 0],
                         [-1, -1],
                         [0, -1],
                         [1, -1]], dtype=np.int32)
    # self.ori = np.array(self.ori, dtype=np.int)
    # self.speed = torch.Tensor(np.array([1, 2, 3, 4], np.uint8)).to(self.device)
    self.speed = np.array([1, 2, 3, 4, 5, 6], dtype=np.int32)  # pixels per step
    # self.speed = np.array([1], dtype=np.int32)
    # self.ori_x = torch.from_numpy(self.ori[:, 0]).to(self.device)
    # self.ori_y = torch.from_numpy(self.ori[:, 1]).to(self.device)
    # Column vectors of the direction components, kept on the device.
    self.ori_x = torch.from_numpy(np.expand_dims(self.ori[:, 0], axis=1)).to(self.device).float()
    self.ori_y = torch.from_numpy(np.expand_dims(self.ori[:, 1], axis=1)).to(self.device).float()

    # Affine warp matrices, one per (direction, speed) hypothesis.
    self.warp_matrix = get_transform_matrix_new(self.ori, self.speed, self.dvs_w, self.dvs_h, self.device)
    self.track_pre = torch.zeros(self.dvs_h, self.dvs_w)  # previous spike frame

    self.num_ori = len(self.ori)
    self.num_speed = len(self.speed)
    self.motion_pattern_num = self.num_ori * self.num_speed
    # Uniform prior over motion patterns at every pixel.
    self.motion_weight = torch.ones(self.motion_pattern_num, 1, self.dvs_h, self.dvs_w) / self.motion_pattern_num
    self.tracking_threshold = 1

    # self.local_pool_size = 21
    self.local_pool_size = 11
    # self.local_pool_size = 5
    padding_width = int((self.local_pool_size - 1) / 2)
    # Box filter (all-ones kernel, 'same' padding) for local spike pooling.
    self.pool_kernel = torch.nn.Conv2d(in_channels=1, out_channels=1,
                                       kernel_size=(self.local_pool_size, self.local_pool_size),
                                       padding=(padding_width, padding_width), bias=False)
    self.pool_kernel.weight.data = torch.ones(1, 1, self.local_pool_size, self.local_pool_size)

    # Gaussian smoothing kernel of the same spatial size.
    self.gaussian_kernel = torch.nn.Conv2d(in_channels=1, out_channels=1,
                                           kernel_size=(self.local_pool_size, self.local_pool_size),
                                           padding=(padding_width, padding_width), bias=False)
    tmp_filter = get_kernel(self.local_pool_size, round(self.local_pool_size / 4))
    tmp_filter = tmp_filter.reshape((1, 1, self.local_pool_size, self.local_pool_size))
    self.gaussian_kernel.weight.data = torch.from_numpy(tmp_filter).float()

    # local wta inhibition size
    # inh_size = 15
    self.inh_size = 25
    # inh_size = 11
    self.padding_width = int((self.inh_size - 1) / 2)
    self.inhb_kernel = torch.nn.Conv2d(in_channels=1, out_channels=1,
                                       kernel_size=(self.inh_size, self.inh_size),
                                       padding=(self.padding_width, self.padding_width), bias=False)
    self.inhb_kernel.weight.data = torch.ones(1, 1, self.inh_size, self.inh_size)
    self.inhb_threshold = 5

    self.track_pre = self.track_pre.to(self.device)
    self.motion_weight = self.motion_weight.to(self.device)
    self.pool_kernel = self.pool_kernel.to(self.device)
    self.gaussian_kernel = self.gaussian_kernel.to(self.device)
    self.inhb_kernel = self.inhb_kernel.to(self.device)

    # cc_motion = [[0, 33, 238],
    #              [79, 0, 255],
    #              [229, 0, 237],
    #              [188, 0, 26],
    #              [191, 198, 0],
    #              [129, 241, 0],
    #              [0, 205, 106],
    #              [0, 205, 198]]

    # Display colors, one per direction, normalized to [0, 1] below.
    cc_motion = [[0, 255, 255],
                 [205, 95, 85],
                 [11, 134, 184],
                 [255, 255, 0],
                 [154, 250, 0],
                 [147, 20, 255],
                 [240, 32, 160],
                 [48, 48, 255]]

    cc_motion = np.transpose(np.array(cc_motion, dtype=np.float32))
    self.cc_motion = torch.from_numpy(cc_motion / 255)
    self.cc_motion = self.cc_motion.to(self.device)
    self.learning_rate = 0.1

    '''
    self.dw_ltp = torch.zeros(self.motion_pattern_num, 1, self.dvs_h, self.dvs_w)
    self.dw_ltd = torch.zeros(self.motion_pattern_num, 1, self.dvs_h, self.dvs_w)
    self.dw_ltp = self.dw_ltp.to(self.device)
    self.dw_ltd = self.dw_ltd.to(self.device)
    '''
|
126 |
+
|
127 |
+
def stdp_tracking(self, spikes):
    """Update the per-pattern motion weight maps with one STDP step.

    Args:
        spikes: (dvs_h, dvs_w) binary spike map of the current timestep,
            a torch tensor on ``self.device``.

    Side effects:
        ``self.motion_weight`` is updated in place and re-normalized;
        ``self.track_pre`` is replaced by a detached copy of ``spikes``.
    """
    track_post = torch.reshape(spikes, (1, 1, self.dvs_h, self.dvs_w))
    # Local spike count around each pixel (box pooling), replicated once
    # per motion pattern; used below to normalize the weight update.
    tmp_pool = self.pool_kernel(track_post)
    tmp_pool = tmp_pool.repeat(self.motion_pattern_num, 1, 1, 1)

    predict_fired = torch.zeros(self.motion_pattern_num, 1, self.dvs_h, self.dvs_w).to(self.device)
    fire_idx = torch.where(spikes != 0)

    # For each (orientation, speed) hypothesis, displace the currently
    # fired pixels by ori * speed to build that pattern's predicted map.
    # NOTE(review): the prediction displaces the *current* spikes and is
    # compared against the *previous* frame (track_pre) below — confirm
    # the sign convention of self.ori matches that direction.
    for i_ori in range(self.num_ori):
        for i_speed in range(self.num_speed):
            i_motion = i_ori * self.num_speed + i_speed
            x = fire_idx[0] + self.ori[i_ori, 0] * self.speed[i_speed]
            y = fire_idx[1] + self.ori[i_ori, 1] * self.speed[i_speed]
            # Out-of-frame predictions are clamped to pixel (0, 0) rather
            # than dropped, so corner (0, 0) may collect spurious votes.
            invalid_idx = torch.logical_or(torch.logical_or(x > self.dvs_h - 1, x < 0),
                                           torch.logical_or(y > self.dvs_w - 1, y < 0))
            x[invalid_idx] = 0
            y[invalid_idx] = 0
            predict_fired[i_motion, 0, x, y] = 1

    # Previous-frame spikes broadcast to one copy per motion pattern.
    track_pre_exp = torch.unsqueeze(self.track_pre, 0).repeat(self.motion_pattern_num, 1, 1)
    track_pre_exp = torch.unsqueeze(track_pre_exp, 1)

    # STDP terms: potentiate (magnitude 1) where a pattern's prediction
    # coincides with a previous-frame spike; depress (magnitude 2, i.e.
    # twice as strong) where it predicted a spike that did not occur.
    dw_ltd = torch.zeros(self.motion_pattern_num, 1, self.dvs_h, self.dvs_w).to(self.device)
    dw_ltp = torch.zeros(self.motion_pattern_num, 1, self.dvs_h, self.dvs_w).to(self.device)

    tmp_bool = torch.eq(predict_fired, track_pre_exp)
    index = torch.where(torch.logical_and(tmp_bool, track_pre_exp == 1))
    if len(index[0]) != 0:
        dw_ltp[index] = 1

    index = torch.where(torch.logical_and(~tmp_bool, predict_fired == 1))
    if len(index[0]) != 0:
        dw_ltd[index] = 2

    # Spatially smooth both terms with the same box pooling.
    dw_ltp = self.pool_kernel(dw_ltp)
    dw_ltd = self.pool_kernel(dw_ltd)

    # Normalize by local spike density; where tmp_pool is 0 this produces
    # NaN, which is zeroed after the min-max normalization below.
    dw = torch.div((dw_ltp - dw_ltd), tmp_pool)
    self.motion_weight += self.learning_rate * dw.detach().clone()

    # Per-pixel min-max normalization across the pattern dimension so the
    # competing motion hypotheses stay on a comparable scale.
    max_weight, _ = torch.max(self.motion_weight, dim=0)
    min_weight, _ = torch.min(self.motion_weight, dim=0)

    for iMotion in range(self.motion_pattern_num):
        tmp_weight = self.motion_weight[iMotion, :, :, :].detach()
        tmp_weight = (tmp_weight - min_weight) / (max_weight - min_weight)
        self.motion_weight[iMotion, :, :, :] = tmp_weight.detach()

    self.motion_weight[torch.isnan(self.motion_weight)] = 0
    # Current spikes become the "previous frame" for the next call.
    self.track_pre = spikes.detach().clone()

    # Free the large intermediates eagerly to keep GPU memory bounded.
    del track_post, tmp_pool, predict_fired, track_pre_exp, tmp_bool, dw
    del tmp_weight, max_weight, min_weight, spikes
    del dw_ltd, dw_ltp
    torch.cuda.empty_cache()
202 |
+
def local_wta(self, spikes, timestamp, visualize=False):
    """Local winner-take-all read-out of the learned motion weights.

    Layer 1 computes a per-pixel motion vector as the mean of the
    orientation basis vectors weighted by the learned pattern weights;
    the vector angles are then quantized into 8 direction sectors and a
    local inhibition (box-sum voting over an inh_size window) selects a
    locally dominant direction (layer 2 / "MC").

    Args:
        spikes: (dvs_h, dvs_w) binary spike map for the current timestep.
        timestamp: global step index, used only to tag logger images.
        visualize: when True, log layer-1/layer-2 color maps to
            ``self.logger`` (a tensorboard-style writer).

    Returns:
        max_motion: (dvs_h, dvs_w) int64 map; 0 = no winning motion,
            otherwise the 1-based index of the winning direction sector.
        motion_vector_max: (dvs_h, dvs_w, 2) WTA-consistent vectors.
        motion_vector_layer1: (dvs_h, dvs_w, 2) raw per-pixel vectors.
    """
    input_spike = torch.reshape(spikes, (1, 1, self.dvs_h, self.dvs_w))

    motion_vector_layer1 = torch.zeros(self.dvs_h, self.dvs_w, 2, dtype=torch.float32).to(self.device)
    # Strongest pattern per pixel across the (ori x speed) hypotheses.
    max_w, max_wid = torch.max(self.motion_weight, dim=0)
    max_wid = torch.squeeze(max_wid)
    speedId = (max_wid % self.num_speed).detach()
    oriId = (torch.floor(max_wid / self.num_speed)).detach()

    # Reorder weights from (ori*speed, 1, h, w) to (h, w, speed, ori) so a
    # matmul with the orientation basis gives per-pixel x/y components.
    tmp_weight = self.motion_weight.permute(2, 3, 1, 0)
    tmp_weight = torch.reshape(tmp_weight, [self.dvs_h, self.dvs_w, self.num_ori, self.num_speed])
    tmp_weight = tmp_weight.permute(0, 1, 3, 2)
    tmp_weight_x = torch.matmul(tmp_weight, self.ori_x)
    tmp_weight_y = torch.matmul(tmp_weight, self.ori_y)

    max_w = torch.squeeze(max_w)
    # Pixels that fired now and have some learned motion evidence.
    fired_spk_index2d = torch.where(torch.logical_and(spikes != 0, max_w > 0))

    # Average over the speed dimension to get one (dx, dy) per pixel.
    tmp_weight_x = torch.mean(tmp_weight_x, dim=2)
    tmp_weight_y = torch.mean(tmp_weight_y, dim=2)
    tmp_weight_x = torch.squeeze(tmp_weight_x)
    tmp_weight_y = torch.squeeze(tmp_weight_y)

    dx = tmp_weight_x[fired_spk_index2d]
    dy = tmp_weight_y[fired_spk_index2d]

    motion_vector_layer1[fired_spk_index2d[0], fired_spk_index2d[1], 0] = dx
    motion_vector_layer1[fired_spk_index2d[0], fired_spk_index2d[1], 1] = dy
    dy_numpy = dy.cpu().numpy()
    dx_numpy = dx.cpu().numpy()

    # Quantize vector angle into 8 sectors of 45 degrees (0..7).
    rotAng = np.arctan2(-dy_numpy, dx_numpy) * 180 / np.pi + 180
    rotAng[np.where(rotAng == 360)] = 0
    tmp_motion = np.floor(rotAng / (360 / 8))
    # NOTE(review): tmp_motion is a float64 numpy array used directly as
    # an index tensor; recent torch requires integer index tensors —
    # confirm this runs on the torch version this project pins.
    track_voltage = torch.zeros(self.num_ori, self.dvs_h, self.dvs_w)
    track_voltage[tmp_motion, fired_spk_index2d[0], fired_spk_index2d[1]] = 1

    # Box-sum each direction channel over the inhibition window: each
    # pixel accumulates votes for each direction from its neighborhood.
    track_voltage = torch.unsqueeze(track_voltage, 1)
    track_voltage = track_voltage.to(self.device)
    track_voltage = torch.squeeze(self.inhb_kernel(track_voltage))
    max_v, max_vid = torch.max(track_voltage, dim=0)

    # Layer-2 winners: enough local votes AND the pixel itself fired.
    fired_layer2_index = torch.where(
        torch.logical_and(max_v >= self.inhb_threshold, torch.logical_and(spikes != 0, max_w > 0)))
    max_motion = torch.zeros(self.dvs_h, self.dvs_w, dtype=torch.int64)
    max_motion_layer1 = torch.zeros(self.dvs_h, self.dvs_w, dtype=torch.int64)
    max_motion_layer1 = max_motion_layer1.to(self.device)

    motion_vector_max = torch.zeros(self.dvs_h, self.dvs_w, 2, dtype=torch.float32)
    max_motion = max_motion.to(self.device)
    motion_vector_max = motion_vector_max.to(self.device)

    # Direction ids are stored 1-based so 0 can mean "no motion".
    max_motion[fired_layer2_index] = max_vid[fired_layer2_index].detach() + 1
    motion_tensor = torch.from_numpy(tmp_motion + 1).to(self.device)
    max_motion_layer1[fired_spk_index2d[0], fired_spk_index2d[1]] = motion_tensor.long()
    max_motion_layer1[max_motion == 0] = 0

    # Resolve pixels whose layer-1 direction lost the local vote: copy the
    # motion vector from the highest-voltage neighbor in their window.
    tmp_vid = max_vid[fired_layer2_index].cpu().detach().numpy()
    if len(tmp_vid) != 0:
        motion_vector_max[fired_layer2_index] = motion_vector_layer1[fired_layer2_index].detach()
        loser_pattern_index = torch.where(torch.logical_and(max_motion != 0, max_motion_layer1 != max_motion))
        fired2_index_x = loser_pattern_index[0]
        fired2_index_y = loser_pattern_index[1]
        # Unfold the voltage map so dim 1 enumerates each pixel's window.
        voltage_block = max_v[None, None, :, :]
        voltage_block = F.pad(voltage_block, (self.padding_width, self.padding_width, self.padding_width, self.padding_width),
                              mode='constant', value=0)
        voltage_block = F.unfold(voltage_block, (self.inh_size, self.inh_size))
        voltage_block = voltage_block.reshape([1, self.inh_size*self.inh_size, self.dvs_h, self.dvs_w])
        offset_pattern = torch.argmax(voltage_block, dim=1)
        offset_pattern = torch.squeeze(offset_pattern)
        offset_pattern_loser = offset_pattern[fired2_index_x, fired2_index_y]
        # Convert the flat window index into (row, col) offsets.
        # NOTE(review): true division then .int() truncates toward zero;
        # that equals floor only because the operands are non-negative.
        offset_x = offset_pattern_loser / self.inh_size - self.padding_width
        offset_y = torch.fmod(offset_pattern_loser, self.inh_size) - self.padding_width
        offset_x = offset_x.int()
        offset_y = offset_y.int()
        # NOTE(review): the offset target is not bounds-checked; border
        # losers could index outside the frame — confirm inputs keep
        # winners away from the image border.
        motion_vector_max[fired2_index_x, fired2_index_y, :] = motion_vector_max[fired2_index_x + offset_x,
                                                                                 fired2_index_y + offset_y, :]

    if visualize is True:
        # Color-code direction sectors for tensorboard inspection.
        Image_layer1 = torch.zeros(3, self.dvs_h, self.dvs_w).to(self.device)
        Image_layer1[:, fired_spk_index2d[0], fired_spk_index2d[1]] = self.cc_motion[:, tmp_motion]

        Image_layer2 = torch.zeros(3, self.dvs_h, self.dvs_w).to(self.device)
        Image_layer2[:, fired_layer2_index[0], fired_layer2_index[1]] = self.cc_motion[:, tmp_vid]

        self.logger.add_image('motion_estimation/M1 estimation', Image_layer1, timestamp)
        self.logger.add_image('motion_estimation/MC estimation', Image_layer2, timestamp)

    # Free the large intermediates eagerly to keep GPU memory bounded.
    del input_spike, fired_spk_index2d, fired_layer2_index
    del track_voltage, dx, dy
    torch.cuda.empty_cache()

    return max_motion, motion_vector_max, motion_vector_layer1
snnTracker/spkProc/tracking/snn_tracker.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Time : 2023/7/16 20:23
|
3 |
+
# @Author : Yajing Zheng
|
4 |
+
# @Email: [email protected]
|
5 |
+
# @File : snn_tracker.py
|
6 |
+
import os, sys
|
7 |
+
sys.path.append('../..')
|
8 |
+
import time
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from spkProc.filters.stp_filters_torch import STPFilter
|
13 |
+
# from filters import stpFilter
|
14 |
+
from spkProc.detection.attention_select import SaccadeInput
|
15 |
+
from spkProc.motion.motion_detection import motion_estimation
|
16 |
+
from spkProc.detection.stdp_clustering import stdp_cluster
|
17 |
+
from utils import NumpyEncoder
|
18 |
+
from collections import namedtuple
|
19 |
+
import json
|
20 |
+
import cv2
|
21 |
+
from tqdm import tqdm
|
22 |
+
|
23 |
+
# Per-object track history: numeric id, x/y centroid samples, matching
# timestamps, and a display color (BGR-style triple).
trajectories = namedtuple('trajectories', 'id x y t color')
|
24 |
+
|
25 |
+
class SNNTracker:
    """Spiking-neural-network multi-object tracker for Vidar spike streams.

    Per-timestep pipeline: STP filtering of the raw spikes -> DNF-based
    attention (region proposals) -> STDP motion estimation -> STDP object
    clustering and track maintenance. Results are appended to a
    MOT-challenge-style text file and optionally rendered to a video.
    """

    def __init__(self, spike_h, spike_w, device, attention_size=20, diff_time=1, **STPargs):
        """Build the tracking pipeline.

        Args:
            spike_h: spike-frame height in pixels.
            spike_w: spike-frame width in pixels.
            device: torch device all sub-modules run on.
            attention_size: side length of the attention boxes proposed by
                the saccade module.
            diff_time: temporal-difference window forwarded to the STP filter.
            **STPargs: optional STPFilter overrides (e.g. filterThr, lifThr);
                empty means "use the filter's defaults".
        """
        self.spike_h = spike_h
        self.spike_w = spike_w
        self.device = device

        # **kwargs is always a dict (never None); check emptiness instead.
        if STPargs:
            self.stp_filter = STPFilter(spike_h, spike_w, device, diff_time, **STPargs)
        else:
            self.stp_filter = STPFilter(spike_h, spike_w, device, diff_time)

        self.attention_size = attention_size
        self.object_detection = SaccadeInput(spike_h, spike_w, box_size=self.attention_size, device=device)

        # Imported lazily so the tracker stays importable without tensorboardX.
        from tensorboardX import SummaryWriter
        logger = SummaryWriter(log_dir='data/log_pkuvidar')
        self.motion_estimator = motion_estimation(spike_h, spike_w, device, logger=logger)

        self.object_cluster = stdp_cluster(spike_h, spike_w, box_size=self.attention_size, device=device)

        self.calibration_time = 150   # default frames consumed by calibrate_motion
        self.timestamps = 0           # global timestep counter across all calls
        self.trajectories = {}        # track id -> trajectories namedtuple
        self.filterd_spikes = []

    def calibrate_motion(self, spikes, calibration_time=None):
        """Warm up the STP filter dynamics on the leading frames.

        Args:
            spikes: (frames, spike_h, spike_w) numpy spike array.
            calibration_time: number of frames to consume; when None, use
                (and keep) self.calibration_time, otherwise remember the
                given value for later bookkeeping.
        """
        if calibration_time is None:
            calibration_time = self.calibration_time
        else:
            self.calibration_time = calibration_time

        print('begin calibrate..')
        for t in range(calibration_time):
            input_spk = torch.from_numpy(spikes[t, :, :]).to(self.device)
            self.stp_filter.update_dynamics(t, input_spk)
            self.timestamps += 1

    def get_results(self, spikes, res_filepath, mov_writer=None, save_video=False):
        """Track objects through `spikes` and append results to `res_filepath`.

        Each visible track emits one MOT-challenge-style CSV row:
        ``frame,id,x,y,w,h,1,-1,-1,-1``. When `save_video` is True, every
        processed frame is also rendered (attention boxes in red, tracks in
        their per-track color) and written through `mov_writer`.

        Args:
            spikes: (frames, spike_h, spike_w) numpy spike array.
            res_filepath: text file the MOT rows are appended to.
            mov_writer: an opened cv2.VideoWriter; required if save_video.
            save_video: whether to render and write the tracking video.
        """
        result_file = open(res_filepath, 'a+')

        timestamps = spikes.shape[0]
        total_time = 0.0
        predict_kwargs = {'spike_h': self.spike_h, 'spike_w': self.spike_w, 'device': self.device}

        # Report the destination path, not the file object's repr.
        for t in tqdm(range(timestamps), desc=f'Saving tracking results to {res_filepath}'):
            try:
                start_time = time.time()

                input_spk = torch.from_numpy(spikes[t, :, :]).to(self.device)
                self.stp_filter.update_dynamics(self.timestamps, input_spk)
                self.stp_filter.local_connect(self.stp_filter.filter_spk)

                # Region proposals from the dynamic neural field.
                self.object_detection.update_dnf(self.stp_filter.lif_spk)
                attentionBox, attentionInput = self.object_detection.get_attention_location(self.stp_filter.lif_spk)

                # Motion estimation (STDP update, then local WTA read-out).
                self.motion_estimator.stdp_tracking(self.stp_filter.lif_spk)
                motion_id, motion_vector, _ = self.motion_estimator.local_wta(
                    self.stp_filter.lif_spk, self.timestamps, visualize=True)

                # Object clustering + track update.
                predict_fire, sw, bw = self.object_cluster.update_weight(attentionInput)
                predict_object = self.object_cluster.detect_object(
                    predict_fire, attentionBox, motion_id, motion_vector, **predict_kwargs)
                sw, bw = self.object_cluster.update_tracks(predict_object, sw, bw, self.timestamps)

                self.object_cluster.synaptic_weight = sw.detach().clone()
                self.object_cluster.bias_weight = bw.detach().clone()

                if save_video:
                    # Render the filtered spike map as the video background.
                    track_frame = self.stp_filter.lif_spk.cpu().numpy()
                    track_frame = (track_frame * 255).astype(np.uint8)
                    track_frame = cv2.cvtColor(track_frame, cv2.COLOR_GRAY2BGR)

                    # Red boxes: raw attention proposals.
                    for i_box in range(attentionBox.shape[0]):
                        tmp_box = attentionBox[i_box, :]
                        cv2.rectangle(track_frame, (int(tmp_box[1]), int(tmp_box[0])),
                                      (int(tmp_box[3]), int(tmp_box[2])), (0, 0, 255), 2)

                for i_box in range(self.object_cluster.K2):
                    if self.object_cluster.tracks[i_box].visible != 1:
                        continue
                    tmp_box = self.object_cluster.tracks[i_box].bbox.numpy()
                    track_id = self.object_cluster.tracks[i_box].id
                    color = self.object_cluster.tracks[i_box].color

                    # bbox rows are [y0, x0, y1, x1].
                    mid_y = (tmp_box[0, 0] + tmp_box[0, 2]) / 2  # height coordinate
                    mid_x = (tmp_box[0, 1] + tmp_box[0, 3]) / 2  # width coordinate
                    box_w = int(tmp_box[0, 3] - tmp_box[0, 1])
                    box_h = int(tmp_box[0, 2] - tmp_box[0, 0])
                    print('%d,%d,%.2f,%.2f,%.2f,%.2f,1,-1,-1,-1' % (
                        self.timestamps, track_id, tmp_box[0, 1], tmp_box[0, 0], box_w, box_h), file=result_file)

                    # Lazily create the track record, then append once
                    # (the original duplicated the appends in both branches).
                    if track_id not in self.trajectories:
                        self.trajectories[track_id] = trajectories(
                            int(track_id), [], [], [], 255 * np.random.rand(1, 3))
                    traj = self.trajectories[track_id]
                    traj.x.append(mid_x)
                    traj.y.append(mid_y)
                    traj.t.append(self.timestamps)

                    if save_video:
                        # Track box in the track's color.
                        cv2.rectangle(track_frame, (int(tmp_box[0, 1]), int(tmp_box[0, 0])),
                                      (int(tmp_box[0, 3]), int(tmp_box[0, 2])),
                                      (int(color[0, 0]), int(color[0, 1]), int(color[0, 2])), 2)
                        # Filled label box above the track box.
                        cv2.rectangle(track_frame, (int(tmp_box[0, 1]), int(tmp_box[0, 0] - 35)),
                                      (int(tmp_box[0, 1] + 60), int(tmp_box[0, 0])),
                                      (int(color[0, 0]), int(color[0, 1]), int(color[0, 2])), -1)
                        # 'predict' marks tracks currently coasting on prediction.
                        if self.object_cluster.tracks[i_box].unvisible_count > 0:
                            show_text = 'predict' + str(track_id)
                        else:
                            show_text = 'object' + str(track_id)
                        cv2.putText(track_frame, show_text, (int(tmp_box[0, 1]), int(tmp_box[0, 0] - 10)),
                                    cv2.FONT_HERSHEY_SIMPLEX,
                                    1, (255, 255, 255), 2)

                if save_video:
                    cv2.putText(track_frame, str(int(self.timestamps)),
                                (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 255), 2)
                    mov_writer.write(track_frame)

                self.timestamps += 1
                # Accumulate wall-clock time so the summary below is accurate
                # (the original printed a total_time that was never updated).
                total_time += time.time() - start_time

            except RuntimeError as exception:
                # Skip a frame on GPU OOM instead of aborting the whole run.
                if "out of memory" in str(exception):
                    print('WARNING: out of memory')
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                else:
                    raise exception

        print('Total tracking took: %.3f seconds for %d timestamps spikes' %
              (total_time, self.timestamps - self.calibration_time))

        result_file.close()

    def save_trajectory(self, results_dir, data_name):
        """Dump trajectories and track boxes as JSON files.

        Writes three files (removed first if present):
        - <results_dir>/<data_name>_py.json: per-id trajectories, one JSON
          object per line;
        - results/<data_name>.json and results/<data_name>_bbox.json: the
          cluster's trajectories/boxes, concatenated back-to-back (kept
          byte-compatible with the original downstream consumers).
        """
        trajectories_filename = os.path.join(results_dir, data_name + '_py.json')
        mat_trajectories_filename = 'results/' + data_name + '.json'
        track_box_filename = 'results/' + data_name + '_bbox.json'

        for filename in (trajectories_filename, mat_trajectories_filename, track_box_filename):
            if os.path.exists(filename):
                os.remove(filename)

        for i_traj in range(self.object_cluster.K2):
            tmp_traj = self.object_cluster.trajectories[i_traj]
            tmp_bbox = self.object_cluster.tracks_bbox[i_traj]

            traj_json_string = json.dumps(tmp_traj._asdict(), cls=NumpyEncoder)
            bbox_json_string = json.dumps(tmp_bbox._asdict(), cls=NumpyEncoder)

            with open(mat_trajectories_filename, 'a+') as f:
                f.write(traj_json_string)

            with open(track_box_filename, 'a+') as f:
                f.write(bbox_json_string)

        for i_traj in self.trajectories:
            traj_json_string = json.dumps(self.trajectories[i_traj]._asdict(), cls=NumpyEncoder)

            with open(trajectories_filename, 'a+') as f:
                f.write(traj_json_string)
                f.write('\n')
|
snnTracker/test_motion_detection.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from spkProc.tracking.snn_tracker import SNNTracker
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import cv2
|
6 |
+
|
7 |
+
|
8 |
+
def load_vidar_dat(filename, frame_cnt=None, width=640, height=480, reverse_spike=True):
    """Load a Vidar .dat spike stream into a dense binary array.

    Each byte of the file packs 8 pixels, least-significant bit first,
    row-major within a frame.

    Args:
        filename: path to the .dat file.
        frame_cnt: number of frames to decode; None decodes all complete
            frames present in the file.
        width: frame width in pixels (width * height must be divisible by 8).
        height: frame height in pixels.
        reverse_spike: when True, flip each frame vertically (rows are
            stored bottom-up in the recording).

    Returns:
        numpy.ndarray of shape (frame_cnt, height, width), float32, values
        in {0.0, 1.0}.
    """
    raw = np.fromfile(filename, dtype=np.uint8)

    len_per_frame = height * width // 8
    if frame_cnt is None:  # 'is None', not '!= None'
        frame_cnt = len(raw) // len_per_frame

    # Vectorized bit unpacking replaces the original per-bit Python loop:
    # bitorder='little' reproduces "bit b of each byte, LSB first".
    usable = raw[:frame_cnt * len_per_frame]
    bits = np.unpackbits(usable, bitorder='little')
    spikes = bits.reshape((frame_cnt, height, width)).astype(np.float32)

    if reverse_spike:
        # Equivalent to np.flipud applied frame by frame.
        spikes = spikes[:, ::-1, :].copy()

    return spikes
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
def detect_motion(spikes, calibration_frames=200, device=None):
    """Detect moving pixels with the SNN motion estimator.

    Args:
        spikes: binary spike array of shape [frames, height, width].
        calibration_frames: number of leading frames used to calibrate the
            STP filter before detection.
        device: torch device to run on; defaults to CUDA when available.

    Returns:
        motion_mask: boolean numpy array (height, width) for the frame at
            index `calibration_frames`; True where a motion pattern won.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    spike_h, spike_w = spikes.shape[1:]

    # Build the full tracking pipeline (only its motion estimator is used).
    spike_tracker = SNNTracker(spike_h, spike_w, device, attention_size=15)

    # Warm up the STP filter on the leading frames.
    calibration_spikes = spikes[:calibration_frames]
    spike_tracker.calibrate_motion(calibration_spikes, calibration_frames)

    # Run motion estimation on the frame right after calibration.
    # NOTE(review): local_wta is called without a preceding stdp_tracking
    # update for this frame — the weights reflect calibration only; confirm
    # this is intended.
    target_frame = spikes[calibration_frames]
    target_frame = torch.from_numpy(target_frame).to(device)

    motion_id, motion_vector, _ = spike_tracker.motion_estimator.local_wta(target_frame, calibration_frames)

    # Pixels assigned a non-zero motion pattern are treated as moving.
    motion_mask = (motion_id > 0).cpu().numpy()

    return motion_mask
|
69 |
+
|
70 |
+
def spikes_to_tfi(spk_seq):
    """Reconstruct intensity frames from spikes via texture-from-interval.

    For every pixel and timestamp the surrounding inter-spike interval is
    estimated in two passes (forward, then backward to fill gaps), and the
    reconstructed intensity is the reciprocal of that interval.

    Args:
        spk_seq: (frames, height, width) binary spike sequence.

    Returns:
        (frames, height, width) float64 array of 1/interval values.
    """
    num_frames, frame_h, frame_w = spk_seq.shape
    prev_pos = np.zeros((1, frame_h, frame_w))
    cur_pos = np.zeros((1, frame_h, frame_w))
    intervals = np.zeros_like(spk_seq).astype(np.float64)

    # Forward pass: at each step record the gap since the previous spike
    # position (0 where no spike arrived at step t+1).
    for t in range(num_frames - 1):
        prev_pos = cur_pos
        nxt = spk_seq[t + 1, :, :]
        cur_pos = nxt * (t + 1) + (1 - nxt) * prev_pos
        intervals[t, :, :] = cur_pos - prev_pos

    # Pixels with no trailing spike default to the full sequence length.
    tail = intervals[num_frames - 1:, :]
    tail[tail == 0] = num_frames
    intervals[num_frames - 1, :, :] = tail

    # Backward pass: propagate the next known interval into the zero gaps.
    carry = num_frames * np.ones((1, frame_h, frame_w))
    for t in range(num_frames - 2, -1, -1):
        nxt = spk_seq[t + 1, :, :]
        carry = nxt * intervals[t, :, :] + (1 - nxt) * carry
        row = np.expand_dims(intervals[t, :, :], 0)
        row[row == 0] = carry[row == 0]
        intervals[t] = row

    return 1.0 / intervals
|
89 |
+
|
90 |
+
def detect_object(spikes, calibration_frames=200, device=None):
    """Run the full SNN tracking pipeline on a spike sequence.

    Despite the name/original docstring, this does not return an object
    mask: it calibrates on the leading frames, tracks the following 200
    frames, writes the rendered tracking video to "testtest.avi", and
    returns 0 — NOTE(review): name/behavior mismatch inherited from the
    original code.

    Args:
        spikes: binary spike array of shape [frames, height, width].
        calibration_frames: number of leading frames used for calibration.
        device: torch device to run on; defaults to CUDA when available.

    Returns:
        0 on completion.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    spike_h, spike_w = spikes.shape[1:]

    # Build the tracker and cap the number of simultaneous tracks at 4.
    spike_tracker = SNNTracker(spike_h, spike_w, device, attention_size=15)
    spike_tracker.object_cluster.K2 = 4

    # Warm up the STP filter on the leading frames.
    calibration_spikes = spikes[:calibration_frames]
    spike_tracker.calibrate_motion(calibration_spikes, calibration_frames)

    # Track the 200 frames that follow calibration.
    target_frame = spikes[calibration_frames: calibration_frames + 200]
    print(target_frame.shape)

    # NOTE(review): the .avi filename is also passed as the text-result
    # path (res_filepath), so MOT rows land in "testtest.avi" alongside
    # the separately written video — presumably unintended; verify.
    save_filename = "testtest.avi"
    mov = cv2.VideoWriter(save_filename, cv2.VideoWriter_fourcc(*'MJPG'), 30, (400, 250))
    spike_tracker.get_results(target_frame, save_filename, mov, save_video=True)

    mov.release()
    cv2.destroyAllWindows()
    return 0
|
126 |
+
|
127 |
+
|
128 |
+
if __name__ == "__main__":
    # Sensor resolution of the .dat recordings.
    height = 250
    width = 400

    # Concatenate the ten recording segments 0.dat .. 9.dat along time.
    spikes = load_vidar_dat("0.dat", width=width, height=height)
    for n in range(1, 10):
        tmp_spikes = load_vidar_dat(f"{n}.dat", width=width, height=height)
        spikes = np.concatenate((spikes, tmp_spikes), axis=0)
    print(spikes.shape)

    # Temporal downsampling: keep every 10th frame.
    spikes = spikes[::10]

    # Run detection/tracking (writes "testtest.avi"; returns 0, not a mask).
    motion_mask = detect_object(spikes, calibration_frames=200)

    # Reconstruct intensity frames via texture-from-interval and save them.
    tfi = spikes_to_tfi(spikes)
    save_recon_filename = "tfi.avi"
    recon_mov = cv2.VideoWriter(save_recon_filename, cv2.VideoWriter_fourcc(*'MJPG'), 30, (width, height))

    for frame in tfi:
        # tfi values are 1/interval in (0, 1]; scale to 8-bit grayscale.
        frame_norm = (frame * 255).astype(np.uint8)
        frame_rgb = cv2.cvtColor(frame_norm, cv2.COLOR_GRAY2BGR)
        recon_mov.write(frame_rgb)

    recon_mov.release()
|
snnTracker/test_snntracker copy.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
# @Time : 2024/12/05 20:17
# @Author : Yajing Zheng
# @Email: [email protected]
# @File : test_snntracker.py
"""Run the SNN tracker on one spike-camera scene, save tracking boxes,
a tracking video, and a 3D trajectory visualization under ./results."""
import os, sys
sys.path.append("..")
import path

import numpy as np
from spkData.load_dat import data_parameter_dict, SpikeStream
from pprint import pprint
import torch
from spkProc.tracking.snn_tracker import SNNTracker
from utils import vis_trajectory
# FIX: merged the two separate `from visualization.get_video import ...` lines
from visualization.get_video import obtain_mot_video, obtain_detection_video
import cv2
# from tracking_mot import TrackingMetrics

# change the path to where you put the datasets
test_scene = "0"
data_filename = test_scene
label_type = 'tracking'
para_dict = data_parameter_dict(data_filename, label_type)
pprint(para_dict)
vidarSpikes = SpikeStream(**para_dict)

# spikes: (T, H, W) binary spike matrix loaded from the .dat stream
spikes = vidarSpikes.get_spike_matrix()
pprint(spikes.shape)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# number of initial frames used only to calibrate the motion filter
calibration_time = 150
filename = path.split_path_into_pieces(data_filename)
result_filename = filename[-1] + '_snn.txt'
# FIX: exist_ok avoids the check-then-create race of the original
# `if not os.path.exists(...)` guard
os.makedirs('results', exist_ok=True)
tracking_file = os.path.join('results', result_filename)
if os.path.exists(tracking_file):
    os.remove(tracking_file)

spike_tracker = SNNTracker(para_dict.get('spike_h'), para_dict.get('spike_w'), device, attention_size=15)
spike_tracker.object_cluster.K2 = 4

# using stp filter to filter out static spikes
spike_tracker.calibrate_motion(spikes, calibration_time)
# start tracking
# FIX: replace only the file extension; str.replace('txt', 'avi') would also
# rewrite any 'txt' substring elsewhere in the path
track_videoName = os.path.splitext(tracking_file)[0] + '.avi'
mov = cv2.VideoWriter(track_videoName, cv2.VideoWriter_fourcc(*'MJPG'), 30,
                      (para_dict.get('spike_w'), para_dict.get('spike_h')))
spike_tracker.get_results(spikes[calibration_time:], tracking_file, mov, save_video=True)

data_name = test_scene
trajectories_filename = os.path.join('results', data_name + '_py.json')
visTraj_filename = os.path.join('results', data_name + '.png')

spike_tracker.save_trajectory('results', data_name)
vis_trajectory(trajectories_filename, visTraj_filename, **para_dict)
mov.release()
cv2.destroyAllWindows()
|
snnTracker/test_snntracker.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
# @Time : 2024/12/05 20:17
# @Author : Yajing Zheng
# @Email: [email protected]
# @File : test_snntracker.py
"""Run the SNN tracker on a motVidarReal2020 scene, save tracking boxes,
a tracking video, and a 3D trajectory visualization under ./results."""
import os, sys
sys.path.append("..")
import path

import numpy as np
from spkData.load_dat import data_parameter_dict, SpikeStream
from pprint import pprint
import torch
from spkProc.tracking.snn_tracker import SNNTracker
from utils import vis_trajectory
# FIX: merged the two separate `from visualization.get_video import ...` lines
from visualization.get_video import obtain_mot_video, obtain_detection_video
import cv2
from tracking_mot import TrackingMetrics

# change the path to where you put the datasets
test_scene = ['spike59', 'rotTrans', 'cplCam', 'cpl1', 'badminton', 'ball']
scene_idx = 2
data_filename = 'motVidarReal2020/' + test_scene[scene_idx]
label_type = 'tracking'
para_dict = data_parameter_dict(data_filename, label_type)
pprint(para_dict)
vidarSpikes = SpikeStream(**para_dict)

# spikes: (T, H, W) binary spike matrix loaded from the .dat stream
spikes = vidarSpikes.get_spike_matrix()
pprint(spikes.shape)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# number of initial frames used only to calibrate the motion filter
calibration_time = 150
filename = path.split_path_into_pieces(data_filename)
result_filename = filename[-1] + '_snn.txt'
# FIX: exist_ok avoids the check-then-create race of the original
# `if not os.path.exists(...)` guard
os.makedirs('results', exist_ok=True)
tracking_file = os.path.join('results', result_filename)
if os.path.exists(tracking_file):
    os.remove(tracking_file)

spike_tracker = SNNTracker(para_dict.get('spike_h'), para_dict.get('spike_w'), device, attention_size=15)
spike_tracker.object_cluster.K2 = 4

# using stp filter to filter out static spikes
spike_tracker.calibrate_motion(spikes, calibration_time)
# start tracking
# FIX: replace only the file extension; str.replace('txt', 'avi') would also
# rewrite any 'txt' substring elsewhere in the path
track_videoName = os.path.splitext(tracking_file)[0] + '.avi'
mov = cv2.VideoWriter(track_videoName, cv2.VideoWriter_fourcc(*'MJPG'), 30,
                      (para_dict.get('spike_w'), para_dict.get('spike_h')))
spike_tracker.get_results(spikes[calibration_time:], tracking_file, mov, save_video=True)

data_name = test_scene[scene_idx]
trajectories_filename = os.path.join('results', data_name + '_py.json')
visTraj_filename = os.path.join('results', data_name + '.png')

spike_tracker.save_trajectory('results', data_name)
vis_trajectory(trajectories_filename, visTraj_filename, **para_dict)
# measure the multi-object tracking performance
# metrics = TrackingMetrics(tracking_file, **para_dict)
# metrics.get_results()
mov.release()
cv2.destroyAllWindows()
|
snnTracker/utils.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import threading
|
5 |
+
import cv2
|
6 |
+
import json
|
7 |
+
|
8 |
+
# import matplotlib
|
9 |
+
# matplotlib.use('TkAgg')
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
from mpl_toolkits.mplot3d import Axes3D
|
12 |
+
from matplotlib.pyplot import MultipleLocator
|
13 |
+
|
14 |
+
|
15 |
+
class dataReader(threading.Thread):
    """Background thread that reads spike frames from .dat / .npy / .png
    sources, converts each to a boolean torch tensor, and pushes it onto a
    queue for a consumer.

    NOTE(review): `run` references module-level globals `tnum`, `ivs_w`,
    `ivs_h` and `device` that are not defined in this file — presumably set
    elsewhere before the thread starts; confirm before use.
    """

    def __init__(self, file_reader, device, q, is_dat=True, is_npy=False, filedir=None):
        super(dataReader, self).__init__()
        # open binary reader for .dat mode (ignored in the other modes)
        self.file_reader = file_reader
        self.device = device
        # queue.Queue the decoded frames are pushed onto
        self.q = q
        # source-format flags: .dat bit-packed stream, .npy stacks, else .png images
        self.is_dat = is_dat
        self.is_npy = is_npy
        # directory prefix for .npy / .png modes
        self.filedir = filedir
        # dedicated CUDA stream so H2D copies can overlap the consumer's work
        self.stream = torch.cuda.Stream()

    def run(self):
        with torch.cuda.stream(self.stream):
            for t in range(tnum):
                if self.is_dat:
                    # one frame = ivs_w*ivs_h bits, packed 8 pixels per byte
                    ibuffer = self.file_reader.read(int(ivs_w * ivs_h / 8))
                    # unpack to a '0'/'1' character string via Python ints
                    a = bin(int.from_bytes(ibuffer, byteorder=sys.byteorder))
                    # strip the '0b' prefix and left-pad dropped leading zeros
                    a = a[2:].zfill(ivs_w * ivs_h)

                    a = list(a)
                    a = np.array(a, dtype=np.byte)
                    a = np.reshape(a, [ivs_h, ivs_w])
                    # sensor-specific orientation fixes
                    # NOTE(review): 600 → vertical flip, 250 → horizontal flip;
                    # presumably matches two camera models — confirm.
                    if ivs_h == 600:
                        a = np.flip(a, 0)
                    if ivs_h == 250:
                        a = np.flip(a, 1)
                    input_spk = torch.from_numpy(a != 0).to(device)
                elif self.is_npy:
                    # NOTE(review): hard-coded 442 frame offset — confirm dataset layout
                    npy_filename = self.filedir + str(t + 442) + '.npy'
                    tmp_data = np.load(npy_filename)
                    # each .npy holds several temporally super-resolved frames
                    superResolution_rate = tmp_data.shape[2]
                    for i_data in range(superResolution_rate):
                        tmp_spk = tmp_data[:, :, i_data]
                        input_spk = torch.from_numpy(tmp_spk).to(device)
                        self.q.put(input_spk)

                else:
                    # img_filename = self.filedir + str(t + 4200) + '.png'
                    img_filename = self.filedir + 'spike_' + str(t + 1) + '.png'
                    # print('reading %d frames' % (t+1))
                    # print('reading %d frames' % (t+5000))
                    a = cv2.imread(img_filename)
                    a = cv2.cvtColor(a, cv2.COLOR_BGR2GRAY)
                    # binarize: any nonzero gray value becomes a spike
                    a = a / 255
                    a = np.array(a, dtype=np.byte)
                    input_spk = torch.from_numpy(a != 0).to(device)

                # NOTE: in .npy mode frames were already queued above, so this
                # enqueues the last super-resolved frame a second time — looks
                # unintentional; confirm.
                self.q.put(input_spk)
|
63 |
+
|
64 |
+
|
65 |
+
# obtain 2D gaussian filter
|
66 |
+
def get_kernel(filter_size, sigma):
    """Return an odd-sized 2D Gaussian kernel, min-max normalized to [0, 1].

    Args:
        filter_size: side length of the square kernel; must be odd.
        sigma: Gaussian standard deviation.

    Returns:
        (filter_size, filter_size) float32 numpy array with 1.0 at the center.
    """
    assert (filter_size + 1) % 2 == 0, '2D filter size must be odd number!'
    half = (filter_size - 1) // 2
    # grid of signed offsets from the kernel center
    rows, cols = np.meshgrid(np.arange(-half, half + 1),
                             np.arange(-half, half + 1), indexing='ij')
    kernel = np.exp(-(rows ** 2 + cols ** 2) / (2.0 * sigma * sigma)).astype(np.float32)
    # rescale so the smallest entry is 0 and the center peak is 1
    kernel = (kernel - kernel.min()) / (kernel.max() - kernel.min())
    return kernel
|
82 |
+
|
83 |
+
|
84 |
+
def get_transform_matrix(ori, speed):
    """Build a (len(ori)*len(speed), 2, 3) bank of affine translation matrices,
    one per (orientation, speed) pair, with shifts normalized by frame size.

    NOTE(review): this legacy variant reads module-level globals `ivs_w`,
    `ivs_h` and `device` that are not defined in this file — prefer
    get_transform_matrix_new, which takes them as parameters; confirm before use.

    Args:
        ori: indexable of (dy, dx) direction rows (e.g. an Nx2 numpy array).
        speed: sequence of scalar speeds.

    Returns:
        torch tensor of 2x3 affine matrices on `device`.
    """
    ori_num = len(ori)
    speed_num = len(speed)
    transform_matrix = torch.zeros(ori_num * speed_num, 2, 3)
    cnt = 0
    for iOri in range(ori_num):
        for iSpeed in range(speed_num):
            # identity part of the affine matrix
            transform_matrix[cnt, 0, 0] = 1
            transform_matrix[cnt, 1, 1] = 1

            # translation normalized to [-1, 1]-style grid coordinates
            transform_matrix[cnt, 0, 2] = - float(ori[iOri, 1] * speed[iSpeed] / ivs_w)
            transform_matrix[cnt, 1, 2] = - float(ori[iOri, 0] * speed[iSpeed] / ivs_h)

            cnt += 1

    transform_matrix = transform_matrix.to(device)
    return transform_matrix
101 |
+
|
102 |
+
|
103 |
+
def get_transform_matrix_new(ori, speed, dvs_w, dvs_h, device):
    """Build a (len(ori)*len(speed), 2, 3) bank of affine translation matrices.

    Each matrix is identity plus a translation of one (orientation, speed)
    pair, with the shift normalized by the frame width/height.

    Args:
        ori: indexable of (dy, dx) direction rows (e.g. an Nx2 numpy array).
        speed: sequence of scalar speeds.
        dvs_w, dvs_h: frame width and height used to normalize the shift.
        device: torch device the matrix bank is moved to.

    Returns:
        torch tensor of shape (len(ori)*len(speed), 2, 3) on `device`.
    """
    mats = torch.zeros(len(ori) * len(speed), 2, 3)
    idx = 0
    for direction in range(len(ori)):
        for s in range(len(speed)):
            # identity part
            mats[idx, 0, 0] = 1
            mats[idx, 1, 1] = 1
            # normalized translation (negated: shift the sampling grid)
            mats[idx, 0, 2] = -float(ori[direction, 1] * speed[s] / dvs_w)
            mats[idx, 1, 2] = -float(ori[direction, 0] * speed[s] / dvs_h)
            idx += 1

    return mats.to(device)
|
120 |
+
|
121 |
+
|
122 |
+
# monitor the inference process
|
123 |
+
def visualize_img(gray_img, tag, curT):
    """Log one grayscale image tensor to tensorboard under `tag` at step `curT`.

    NOTE(review): `logger` is not defined in this module — presumably a
    tensorboard SummaryWriter set as a global elsewhere; confirm.
    """
    # BUG FIX: torch.Tensor has no .float32() method (AttributeError);
    # .float() casts to float32, which is what was intended.
    gray_img = gray_img.float()
    # add a leading channel dim: add_image expects (C, H, W)
    img = torch.unsqueeze(gray_img, 0)
    logger.add_image(tag, img, global_step=curT)
|
127 |
+
|
128 |
+
|
129 |
+
def visualize_images(images, tag, curT):
    """Log each last-axis slice of `images` as a separate tensorboard image.

    NOTE(review): relies on a module-global `logger` (SummaryWriter) defined
    elsewhere — confirm.
    """
    if images.shape[0] < 1:
        return  # nothing to log
    stack = torch.squeeze(images)
    for idx in range(stack.shape[-1]):
        # pull one (H, W) slice and give it a leading channel dim for add_image
        channel = torch.unsqueeze(torch.squeeze(stack[:, :, idx]), 0)
        logger.add_image(tag + str(idx), channel, global_step=curT)
|
139 |
+
|
140 |
+
|
141 |
+
def visualize_weights(weights, tag, curT):
    """Log each weight row as a min-max-normalized square grayscale image.

    Assumes every row has a perfect-square length (reshaped to side x side).
    NOTE(review): relies on a module-global `logger` (SummaryWriter) defined
    elsewhere — confirm.
    """
    if weights.shape[0] < 1:
        return  # nothing to log
    weights = torch.squeeze(weights)
    side = int(np.sqrt(weights.shape[1]))
    for row in range(weights.shape[0]):
        w = torch.squeeze(weights[row, :])
        # normalize to [0, 1] so the image display range is consistent
        w = (w - torch.min(w)) / (torch.max(w) - torch.min(w))
        w = torch.reshape(w, (side, side))
        logger.add_image(tag + str(row), torch.unsqueeze(w, 0), global_step=curT)
|
155 |
+
|
156 |
+
|
157 |
+
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that serializes numpy arrays as plain nested lists."""

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # anything else: defer to the base class (raises TypeError if unsupported)
        return super().default(obj)
|
162 |
+
|
163 |
+
|
164 |
+
def vis_trajectory(json_file, filename, **dataDict):
    """Render tracked trajectories as a 3D (time, width, height) plot and save it.

    Args:
        json_file: path to a file with one JSON trajectory per line, each with
            keys 't', 'x', 'y', 'color' (0-255 RGB) and 'id'.
        filename: output image path for the saved figure.
        **dataDict: must provide 'spike_h' and 'spike_w' (frame size).
    """
    spike_h = dataDict.get('spike_h')
    spike_w = dataDict.get('spike_w')
    traj_dict = []
    with open(json_file, 'r') as f:
        for line in f.readlines():
            traj_dict.append(json.loads(line))

    num_traj = len(traj_dict)

    fig = plt.figure(figsize=[10, 6])
    ax = fig.add_subplot(111, projection='3d')
    # running time range over all trajectories (initial bounds get overwritten)
    min_t = 1000
    max_t = 0

    for tmp_traj in traj_dict:
        tmp_t = np.array(tmp_traj['t'])
        if np.min(tmp_t) < min_t:
            min_t = np.min(tmp_t)
        if np.max(tmp_t) > max_t:
            max_t = np.max(tmp_t)

        # mirror x so the plot matches the camera's coordinate convention
        tmp_x = spike_w - np.array(tmp_traj['x'])
        tmp_y = np.array(tmp_traj['y'])
        tmp_color = np.array(tmp_traj['color']) / 255.
        ax.plot(tmp_t, tmp_x, tmp_y, color=tmp_color, linewidth=2, label='traj ' + str(tmp_traj['id']))

    ax.legend(loc='best', bbox_to_anchor=(0.7, 0., 0.4, 0.8))
    # stretch the time axis and compress the spatial axes for readability
    zoom = [2.2, 0.8, 0.5, 1]
    ax.get_proj = lambda: np.dot(Axes3D.get_proj(ax), np.diag([zoom[0], zoom[1], zoom[2], zoom[3]]))
    ax.set_xlim(min_t, max_t)
    ax.set_ylim(0, spike_w)
    ax.set_zlim(0, spike_h)

    ax.set_xlabel('time', fontsize=15)
    ax.set_ylabel('width', fontsize=15)
    ax.set_zlabel('height', fontsize=15)

    ax.view_init(elev=16, azim=135)
    ax.yaxis.set_major_locator(MultipleLocator(100))
    fig.subplots_adjust(top=1., bottom=0., left=0.2, right=1.)
    # BUG FIX: save BEFORE show — plt.show() can clear/destroy the current
    # figure, so saving afterwards could write a blank image.
    plt.savefig(filename, dpi=500, transparent=True)
    plt.show()
|
snnTracker/visualization/get_image.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Time : 2023/8/20 16:06
|
3 |
+
# @Author : Yajing Zheng
|
4 |
+
# @Email: [email protected]
|
5 |
+
# @File : get_image.py
|
6 |
+
import numpy as np
|
7 |
+
import matplotlib
|
8 |
+
matplotlib.use('Agg')
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
import json
|
11 |
+
from mpl_toolkits.mplot3d import Axes3D
|
12 |
+
from matplotlib.pyplot import MultipleLocator
|
13 |
+
from matplotlib.patches import Rectangle
|
14 |
+
import torch
|
15 |
+
import copy
|
16 |
+
|
17 |
+
def get_spike_raster(data):
    """Event-plot a (num_neuron, timesteps) spike raster; returns the figure."""
    num_neuron, _ = data.shape
    palette = [f'C{i}' for i in range(num_neuron)]
    # stack the rows vertically with a little overlap between neighbors
    offsets = np.arange(1, num_neuron * 2 + 1, 2)
    lengths = np.full((num_neuron,), 1.5)

    plt.figure(figsize=(8, 6))
    plt.eventplot(data, colors=palette, lineoffsets=offsets, linelengths=lengths)
    return plt.gcf()
|
28 |
+
|
29 |
+
|
30 |
+
def get_heatmap_handle(data, marker=None, bounding_box=None):
    """Draw `data` as a blue heatmap, optionally overlaying labeled point
    markers and red bounding boxes; returns the matplotlib figure.

    Args:
        data: 2D array or torch tensor to display.
        marker: optional 2xN array of (row, col) points to annotate as P0..PN.
        bounding_box: optional iterable of (top, left, bottom, right) boxes.
    """
    if torch.is_tensor(data):
        data = copy.deepcopy(data.cpu().detach().numpy())

    fig, ax = plt.subplots(figsize=(8, 6))
    h, w = data.shape
    if marker is not None:
        n_pts = marker.shape[1]
        palette = [f'C{i}' for i in range(n_pts)]
        for p in range(n_pts):
            # flip the row coordinate so markers match image orientation
            ax.plot(marker[1, p], h - marker[0, p], 'o', color=palette[p], markersize=10)
            ax.annotate('P{}'.format(p), (marker[1, p], h - marker[0, p]))

    if bounding_box is not None:
        for bbox in bounding_box:
            # bbox is (top, left, bottom, right); Rectangle wants (x, y), w, h
            ax.add_patch(Rectangle((bbox[1], bbox[0]), bbox[3] - bbox[1], bbox[2] - bbox[0],
                                   edgecolor='red', facecolor='none', lw=2))

    ax.imshow(data, cmap='Blues', interpolation='nearest')

    plt.axis('off')  # hide axes (optional)
    plt.title('Heatmap')

    return plt.gcf()
|
56 |
+
|
57 |
+
|
58 |
+
def get_histogram_handle(data, marker=None, bounding_box=None):
    """Plot a 20-bin histogram of all values in `data`; returns the figure.

    `marker` and `bounding_box` are accepted for interface parity with
    get_heatmap_handle but are currently unused.
    """
    if torch.is_tensor(data):
        data = copy.deepcopy(data.cpu().detach().numpy())

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.hist(data.reshape((-1, 1)), bins=20)

    # FIX: title previously said 'Heatmap' — copy-paste from get_heatmap_handle
    plt.title('Histogram')

    return plt.gcf()
|
72 |
+
def vis_trajectory(box_file, json_file, filename, **dataDict):
    """Render trajectories from `json_file` as a 3D (time, width, height) plot.

    Args:
        box_file: path to a box result file.
            NOTE(review): its lines are read but never used — confirm intent.
        json_file: file with one JSON trajectory per line ('t', 'x', 'y',
            'color', 'id' keys).
        filename: output image path (saving is currently commented out).
        **dataDict: must provide 'spike_h' and 'spike_w'.
    """
    spike_h = dataDict.get('spike_h')
    spike_w = dataDict.get('spike_w')
    traj_dict = []
    with open(json_file, 'r') as f:
        for line in f.readlines():
            traj_dict.append(json.loads(line))

    # FIX: the box file handle was opened and never closed (leak);
    # a context manager closes it deterministically.
    with open(box_file, 'r') as bf:
        result_lines = bf.readlines()
    num_traj = len(traj_dict)

    fig = plt.figure(figsize=[10, 6])
    ax = fig.add_subplot(111, projection='3d')
    # running time range over all trajectories
    min_t = 1000
    max_t = 0

    for tmp_traj in traj_dict:
        tmp_t = np.array(tmp_traj['t'])
        if np.min(tmp_t) < min_t:
            min_t = np.min(tmp_t)
        if np.max(tmp_t) > max_t:
            max_t = np.max(tmp_t)

        # mirror x so the plot matches the camera's coordinate convention
        tmp_x = spike_w - np.array(tmp_traj['x'])
        tmp_y = np.array(tmp_traj['y'])
        tmp_color = np.array(tmp_traj['color']) / 255.
        ax.plot(tmp_t, tmp_x, tmp_y, color=tmp_color, linewidth=2, label='traj ' + str(tmp_traj['id']))

    ax.legend(loc='best', bbox_to_anchor=(0.7, 0., 0.4, 0.8))
    # stretch the time axis and compress the spatial axes for readability
    zoom = [2.2, 0.8, 0.5, 1]
    ax.get_proj = lambda: np.dot(Axes3D.get_proj(ax), np.diag([zoom[0], zoom[1], zoom[2], zoom[3]]))
    ax.set_xlim(min_t, max_t)
    ax.set_ylim(0, spike_w)
    ax.set_zlim(0, spike_h)

    ax.set_xlabel('time', fontsize=15)
    ax.set_ylabel('width', fontsize=15)
    ax.set_zlabel('height', fontsize=15)

    ax.view_init(elev=16, azim=135)
    # ax.view_init(elev=2, azim=27)
    ax.yaxis.set_major_locator(MultipleLocator(100))
    fig.subplots_adjust(top=1., bottom=0., left=0.2, right=1.)
    # plt.savefig(filename, dpi=500, transparent=True)
    # filename = filename.replace('png', 'eps')
    # plt.savefig(filename, dpi=500, transparent=True)
    plt.show()
|
snnTracker/visualization/get_video.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Time : 2022/6/12 15:21
|
3 |
+
# @Author : Yajing Zheng
|
4 |
+
# @File : visualize.py
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import matplotlib
|
8 |
+
matplotlib.use('Agg')
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
|
11 |
+
import matplotlib.cm as cm
|
12 |
+
import matplotlib.animation as animation
|
13 |
+
|
14 |
+
def obtain_spike_video(spikes, video_filename, **dataDict):
    """Write a binary spike tensor (T, H, W) to an MJPG video at 30 fps.

    `dataDict` must provide 'spike_h' and 'spike_w' (frame size).
    """
    height = dataDict.get('spike_h')
    width = dataDict.get('spike_w')

    writer = cv2.VideoWriter(video_filename, cv2.VideoWriter_fourcc(*'MJPG'), 30, (width, height))

    for frame_idx in range(spikes.shape[0]):
        # scale {0,1} spikes to {0,255} gray and expand to BGR for the writer
        frame = (spikes[frame_idx, :, :] * 255).astype(np.uint8)
        writer.write(cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR))

    writer.release()
|
27 |
+
|
28 |
+
|
29 |
+
def obtain_reconstruction_video(images, video_filename, **dataDict):
    """Write a stack of grayscale images (N, H, W) to an MJPG video at 30 fps.

    `dataDict` must provide 'spike_h' and 'spike_w' (frame size).
    """
    spike_h = dataDict.get('spike_h')
    spike_w = dataDict.get('spike_w')

    img_num = images.shape[0]
    mov = cv2.VideoWriter(video_filename, cv2.VideoWriter_fourcc(*'MJPG'), 30, (spike_w, spike_h))
    for iImg in range(img_num):
        tmp_img = images[iImg, :, :]
        # FIX: cast to uint8 before cvtColor, matching obtain_spike_video —
        # cv2.cvtColor rejects e.g. float64 frames produced upstream.
        tmp_img = cv2.cvtColor(tmp_img.astype(np.uint8), cv2.COLOR_GRAY2BGR)
        mov.write(tmp_img)

    mov.release()
|
41 |
+
|
42 |
+
|
43 |
+
def obtain_mot_video(spikes, video_filename, res_filepath, **dataDict):
    """Render multi-object-tracking results over spike frames into an MJPG video.

    Ground-truth boxes (if dataDict['labeled_data_dir'] points to a CSV of
    `frame,id,x,y,w,h` lines) are drawn in white; tracked boxes from
    `res_filepath` (same format) are drawn in a random per-track color.

    Args:
        spikes: (T, H, W) binary spike array.
        video_filename: output .avi path (MJPG, 30 fps).
        res_filepath: tracking result file, one `frame,id,x,y,w,h` line per box.
        **dataDict: must provide 'spike_h' and 'spike_w'; optionally
            'labeled_data_dir' for ground truth.
    """
    spike_h = dataDict.get('spike_h')
    spike_w = dataDict.get('spike_w')

    gt_file = dataDict.get('labeled_data_dir')
    # frame index (as str) -> list of [id, x, y, w, h]
    gt_boxes = {}
    if gt_file is not None:
        gt_f = open(gt_file, 'r')
        gt_lines = gt_f.readlines()
        for line in gt_lines:
            gt_term = line.split(',')
            time_step = gt_term[0]
            box_id = gt_term[1]
            x = float(gt_term[2])
            y = float(gt_term[3])
            w = float(gt_term[4])
            h = float(gt_term[5])

            if str(time_step) not in gt_boxes:
                gt_boxes[str(time_step)] = []
            bbox = [box_id, x, y, w, h]
            gt_boxes[str(time_step)].append(bbox)

        gt_f.close()

    result_file = res_filepath
    # frame index (as str) -> list of [track_id, x, y, w, h]
    test_boxes = {}
    result_f = open(result_file, 'r')
    result_lines = result_f.readlines()
    # track_id -> fixed random BGR color, so a track keeps one color
    color_dict = {}

    for line in result_lines:
        res_box = line.split(',')
        time_step = res_box[0]
        track_id = res_box[1]
        if track_id not in color_dict.keys():
            colors = (np.random.rand(1, 3) * 255).astype(np.uint8)
            color_dict[track_id] = np.squeeze(colors)

        x = float(res_box[2])
        y = float(res_box[3])
        w = float(res_box[4])
        h = float(res_box[5])

        if str(time_step) not in test_boxes:
            test_boxes[str(time_step)] = []

        test_box = [track_id, x, y, w, h]
        test_boxes[str(time_step)].append(test_box)

    result_f.close()

    mov = cv2.VideoWriter(video_filename, cv2.VideoWriter_fourcc(*'MJPG'), 30, (spike_w, spike_h))

    timestamps = spikes.shape[0]
    # NOTE(review): rendering starts at frame 151 — presumably skipping the
    # tracker's calibration window; confirm it matches calibration_time.
    for t in range(151, timestamps):
        # for t in range(160, 1000):
        tmp_ivs = spikes[t, :, :] * 255
        tmp_ivs = cv2.cvtColor(tmp_ivs.astype(np.uint8), cv2.COLOR_GRAY2BGR)

        if len(gt_boxes) > 0:
            if str(t) in gt_boxes:
                gts = gt_boxes[str(t)]
                gt_num = len(gts)
                for i in range(gt_num):
                    box = gts[i]
                    box_id = box[0]
                    # NOTE(review): rectangle uses (box[2], box[1]) as the
                    # corner, i.e. box stores (row, col) rather than (x, y)
                    # despite the parsing names — confirm the convention.
                    cv2.rectangle(tmp_ivs, (int(box[2]), int(box[1])),
                                  (int(box[2] + box[4]), int(box[1] + box[3])),
                                  (int(255), int(255), int(255)), 2)

        if str(t) in test_boxes:
            test = test_boxes[str(t)]
            test_num = len(test)
            for i in range(test_num):
                box = test[i]
                box_id = box[0]
                colors = color_dict[box_id]
                cv2.rectangle(tmp_ivs, (int(box[2]), int(box[1])),
                              (int(box[2] + box[4]), int(box[1] + box[3])),
                              (int(colors[0]), int(colors[1]), int(colors[2])), 2)

        mov.write(tmp_ivs)

    mov.release()
|
128 |
+
|
129 |
+
|
130 |
+
def obtain_detection_video(spikes, video_filename, res_filepath, evaluate_seq_len, begin_idx=0, **dataDict):
    """Overlay per-sequence ground-truth and detection boxes on sampled spike
    frames and write them to an MJPG video.

    Args:
        spikes: (T, H, W) binary spike array.
        video_filename: output .avi path (MJPG, 30 fps).
        res_filepath: detection result file, one `frame,id,x,y,w,h` line per box.
        evaluate_seq_len: number of labeled sequences to render.
        begin_idx: first sequence index to render.
        **dataDict: must provide 'spike_h' and 'spike_w'; optionally
            'labeled_data_dir' as an indexable of per-sequence GT file paths.
    """
    spike_h = dataDict.get('spike_h')
    spike_w = dataDict.get('spike_w')

    # BUG FIX: start_idx/end_idx are needed by the rendering loop below even
    # when no ground-truth file is provided; previously they were defined only
    # inside the `if gt_file is not None:` branch, raising NameError otherwise.
    start_idx = begin_idx
    end_idx = begin_idx + evaluate_seq_len

    gt_file = dataDict.get('labeled_data_dir')
    # sequence index (as str) -> list of [id, x, y, w, h]
    gt_boxes = {}
    if gt_file is not None:
        for seq_no in range(start_idx, end_idx):
            gt_filename = gt_file[seq_no]
            gt_f = open(gt_filename, 'r')

            gt_lines = gt_f.readlines()
            for line in gt_lines:
                tmp_box = line.split(',')

                x = float(tmp_box[0])
                y = float(tmp_box[1])
                w = float(tmp_box[2])
                h = float(tmp_box[3])
                box_id = int(0)  # detection GT carries no identity

                if str(seq_no) not in gt_boxes:
                    gt_boxes[str(seq_no)] = []
                bbox = [box_id, x, y, w, h]
                gt_boxes[str(seq_no)].append(bbox)

            gt_f.close()

    result_file = res_filepath
    # frame index (as str) -> list of [track_id, x, y, w, h]
    test_boxes = {}
    result_f = open(result_file, 'r')
    result_lines = result_f.readlines()
    # track_id -> fixed random BGR color
    color_dict = {}

    for line in result_lines:
        res_box = line.split(',')
        time_step = res_box[0]
        track_id = res_box[1]
        if track_id not in color_dict.keys():
            colors = (np.random.rand(1, 3) * 255).astype(np.uint8)
            color_dict[track_id] = np.squeeze(colors)

        x = float(res_box[2])
        y = float(res_box[3])
        w = float(res_box[4])
        h = float(res_box[5])

        if str(time_step) not in test_boxes:
            test_boxes[str(time_step)] = []

        test_box = [track_id, x, y, w, h]
        test_boxes[str(time_step)].append(test_box)

    result_f.close()

    mov = cv2.VideoWriter(video_filename, cv2.VideoWriter_fourcc(*'MJPG'), 30, (spike_w, spike_h))

    block_len = spikes.shape[0]
    # gt_intv = int(block_len/evaluate_seq_len)
    # NOTE(review): hard-coded interval between labeled sequences; the
    # commented formula suggests it should depend on block_len — confirm.
    gt_intv = 400

    for i_gt in range(start_idx + 1, end_idx):
        # sample the middle frame of each labeled interval
        t = i_gt * gt_intv + int(gt_intv / 2)
        tmp_ivs = spikes[t, :, :] * 255
        tmp_ivs = cv2.cvtColor(tmp_ivs.astype(np.uint8), cv2.COLOR_GRAY2BGR)

        if len(gt_boxes) > 0:
            gts = gt_boxes[str(i_gt)]
            gt_num = len(gts)
            for i in range(gt_num):
                box = gts[i]
                # GT boxes are mirrored horizontally to match the frame layout
                cv2.rectangle(tmp_ivs, (int(spike_w - box[1]), int(box[2])),
                              (int(spike_w - box[1] - box[3]), int(box[2] + box[4])),
                              (int(255), int(255), int(255)), 2)

        if str(t) in test_boxes:
            test = test_boxes[str(t)]
            test_num = len(test)
            for i in range(test_num):
                box = test[i]
                box_id = box[0]
                colors = color_dict[box_id]
                cv2.rectangle(tmp_ivs, (int(box[2]), int(box[1])),
                              (int(box[2] + box[4]), int(box[1] + box[3])),
                              (int(colors[0]), int(colors[1]), int(colors[2])), 2)

        mov.write(tmp_ivs)

    mov.release()
|
222 |
+
|
223 |
+
def get_heatVideo(results, video_filename):
    """Animate a stack of 2D maps as a blue-colormap heat video, save it via
    ffmpeg (libx264), and display the animation.

    Args:
        results: sequence of 2D arrays (converted to one numpy stack).
        video_filename: output video path passed to FFMpegWriter.
    """
    results = np.array(results)
    frame_num = results.shape[0]
    frames = []

    fig = plt.figure()
    for i in range(frame_num):
        tmp_res = results[i]
        # frames.append([plt.imshow(tmp_res, cmap=cm.Greys_r, animated=True)])
        frames.append([plt.imshow(tmp_res, cmap=cm.Blues, animated=True)])

    # 50 ms between frames while previewing; the saved file uses fps=30 below
    ani = animation.ArtistAnimation(fig, frames, interval=50, blit=True,
                                    repeat_delay=1000)

    # change the path to where you save ffmpeg
    # NOTE(review): hard-coded Windows-specific ffmpeg path — this breaks on
    # any other machine; should come from config or PATH.
    plt.rcParams['animation.ffmpeg_path'] = 'F:\\ffmpeg-N-99818-g993429cfb4-win64-gpl-shared-vulkan\\bin\\ffmpeg.exe'
    FFwrite = animation.FFMpegWriter(fps=30, extra_args=['-vcodec', 'libx264'])
    ani.save(video_filename, writer=FFwrite)
    plt.show()
|
242 |
+
|
243 |
+
|
244 |
+
|
245 |
+
|
snnTracker/visualization/optical_flow_visualization.py
ADDED
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Time : 2022/7/21
|
3 |
+
# @Author : Rui Zhao
|
4 |
+
# @File : optical_flow_visualization.py
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
import cv2
|
9 |
+
import math
|
10 |
+
|
11 |
+
|
12 |
+
#################### Interface ####################
|
13 |
+
def flow_visualization(flow, mode='normal', use_cv2=True):
    """
    Dispatch to one of the three flow-to-image renderers.

    Args:
        flow (np.ndarray): flow field of shape [H,W,2] (u in channel 0, v in channel 1).
        mode (str): 'normal' (RAFT-style), 'scflow', or 'evflow'.
        use_cv2 (bool): when True, return BGR for cv2.imwrite/imshow; when
            False, return RGB (e.g. for matplotlib).

    Returns:
        np.ndarray: uint8 visualization image of shape [H,W,3].

    Raises:
        ValueError: if `mode` is not one of the supported names.
            (Previously an unknown mode fell through and raised a confusing
            NameError on the unbound `flow_vis`.)
    """
    if mode == 'normal':
        flow_vis = flow_to_image(flow_uv=flow, convert_to_bgr=use_cv2)
    elif mode == 'scflow':
        flow_vis = flow_to_img_scflow(flow_uv=flow)
        if not use_cv2:
            flow_vis = cv2.cvtColor(flow_vis, cv2.COLOR_BGR2RGB)
    elif mode == 'evflow':
        # NOTE(review): flow_viz_np produces BGR regardless of use_cv2 — confirm.
        flow_vis = flow_viz_np(flow_x=flow[:,:,0], flow_y=flow[:,:,1])
    else:
        raise ValueError(
            "unknown flow visualization mode: {!r} "
            "(expected 'normal', 'scflow' or 'evflow')".format(mode))

    return flow_vis
|
24 |
+
|
25 |
+
|
26 |
+
def vis_color_map(use_cv2=True):
    """
    Render the color-coding key for all three visualization modes on a
    200x200 grid of flow vectors spanning [-100, 99] in both u and v.

    Args:
        use_cv2 (bool): forwarded to flow_visualization (BGR vs RGB output).

    Returns:
        list[np.ndarray]: [normal_map, scflow_map, evflow_map].
    """
    axis = np.linspace(-100, 99, 200)
    grid_u, grid_v = np.meshgrid(axis, axis)
    flow = np.concatenate((grid_u[:, :, None], grid_v[:, :, None]), axis=2)
    return [
        flow_visualization(flow=flow, mode=name, use_cv2=use_cv2)
        for name in ('normal', 'scflow', 'evflow')
    ]
|
35 |
+
|
36 |
+
|
37 |
+
def make_colorwheel():
    """
    Generate the Middlebury optical-flow color wheel as presented in:
    Baker et al. "A Database and Evaluation Methodology for Optical Flow"
    (ICCV, 2007), http://vision.middlebury.edu/flow/flowEval-iccv07.pdf

    Follows Daniel Scharstein's C++ source and Deqing Sun's Matlab source.

    Returns:
        np.ndarray: color wheel of shape [55, 3], values in [0, 255].
    """
    # Segment lengths for each hue transition around the wheel.
    RY, YG, GC, CB, BM, MR = 15, 6, 4, 11, 13, 6
    ncols = RY + YG + GC + CB + BM + MR
    wheel = np.zeros((ncols, 3))

    # Each transition: (length, channel held at 255, ramping channel, direction).
    # direction +1 ramps 0 -> 255 across the segment, -1 ramps 255 -> 0.
    transitions = [
        (RY, 0, 1, +1),  # red -> yellow
        (YG, 1, 0, -1),  # yellow -> green
        (GC, 1, 2, +1),  # green -> cyan
        (CB, 2, 1, -1),  # cyan -> blue
        (BM, 2, 0, +1),  # blue -> magenta
        (MR, 0, 2, -1),  # magenta -> red
    ]
    row = 0
    for length, hold_ch, ramp_ch, direction in transitions:
        wheel[row:row + length, hold_ch] = 255
        ramp = np.floor(255 * np.arange(length) / length)
        if direction < 0:
            ramp = 255 - ramp
        wheel[row:row + length, ramp_ch] = ramp
        row += length
    return wheel
|
85 |
+
|
86 |
+
|
87 |
+
#################### Normal Version ####################
|
88 |
+
"""
|
89 |
+
From https://github.com/princeton-vl/RAFT/blob/master/core/utils/flow_viz.py
|
90 |
+
"""
|
91 |
+
def flow_uv_to_colors(u, v, convert_to_bgr=False):
    """
    Applies the flow color wheel to (possibly clipped) flow components u and v.

    According to the C++ source code of Daniel Scharstein and the Matlab
    source code of Deqing Sun. Flow magnitudes are expected to be normalized
    to roughly [0, 1]; anything larger is rendered desaturated.

    Args:
        u (np.ndarray): Input horizontal flow of shape [H,W]
        v (np.ndarray): Input vertical flow of shape [H,W]
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    colorwheel = make_colorwheel()  # shape [55, 3]
    ncols = colorwheel.shape[0]

    magnitude = np.sqrt(np.square(u) + np.square(v))
    angle = np.arctan2(-v, -u) / np.pi            # in [-1, 1]
    wheel_pos = (angle + 1) / 2 * (ncols - 1)     # fractional wheel index
    idx_lo = np.floor(wheel_pos).astype(np.int32)
    idx_hi = idx_lo + 1
    idx_hi[idx_hi == ncols] = 0                   # wrap around the wheel
    frac = wheel_pos - idx_lo

    in_range = (magnitude <= 1)
    image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
    for ch in range(colorwheel.shape[1]):
        channel = colorwheel[:, ch] / 255.0
        # Linear blend between the two nearest wheel entries.
        blended = (1 - frac) * channel[idx_lo] + frac * channel[idx_hi]
        # Shade toward white with decreasing magnitude; desaturate out-of-range.
        blended[in_range] = 1 - magnitude[in_range] * (1 - blended[in_range])
        blended[~in_range] = blended[~in_range] * 0.75  # out of range
        # 2-ch writes channels in reverse => BGR instead of RGB.
        target = 2 - ch if convert_to_bgr else ch
        image[:, :, target] = np.floor(255 * blended)
    return image
|
125 |
+
|
126 |
+
|
127 |
+
def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
    """
    Render a flow field as a color image using the Middlebury color wheel.

    Args:
        flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
        clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
    if clip_flow is not None:
        flow_uv = np.clip(flow_uv, 0, clip_flow)
    u = flow_uv[:, :, 0]
    v = flow_uv[:, :, 1]
    # Normalize by the largest magnitude so colors span the full wheel;
    # epsilon guards against division by zero on an all-zero field.
    magnitude = np.sqrt(np.square(u) + np.square(v))
    scale = np.max(magnitude) + 1e-5
    return flow_uv_to_colors(u / scale, v / scale, convert_to_bgr)
|
149 |
+
|
150 |
+
|
151 |
+
#################### SCFlow Version ####################
|
152 |
+
def flow_uv_to_colors_scflow(u, v, convert_to_bgr=False):
    """
    Applies the flow color wheel to (possibly clipped) flow components u and v,
    SCFlow variant: identical to flow_uv_to_colors except the angle is taken
    as arctan2(-v, u) instead of arctan2(-v, -u), rotating the color mapping.

    According to the C++ source code of Daniel Scharstein and the Matlab
    source code of Deqing Sun.

    Args:
        u (np.ndarray): Input horizontal flow of shape [H,W]
        v (np.ndarray): Input vertical flow of shape [H,W]
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    colorwheel = make_colorwheel()  # shape [55, 3]
    ncols = colorwheel.shape[0]

    magnitude = np.sqrt(np.square(u) + np.square(v))
    angle = np.arctan2(-v, u) / np.pi             # SCFlow orientation convention
    wheel_pos = (angle + 1) / 2 * (ncols - 1)
    idx_lo = np.floor(wheel_pos).astype(np.int32)
    idx_hi = idx_lo + 1
    idx_hi[idx_hi == ncols] = 0                   # wrap around the wheel
    frac = wheel_pos - idx_lo

    in_range = (magnitude <= 1)
    image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
    for ch in range(colorwheel.shape[1]):
        channel = colorwheel[:, ch] / 255.0
        blended = (1 - frac) * channel[idx_lo] + frac * channel[idx_hi]
        blended[in_range] = 1 - magnitude[in_range] * (1 - blended[in_range])
        blended[~in_range] = blended[~in_range] * 0.75  # out of range
        # 2-ch writes channels in reverse => BGR instead of RGB.
        target = 2 - ch if convert_to_bgr else ch
        image[:, :, target] = np.floor(255 * blended)
    return image
|
189 |
+
|
190 |
+
|
191 |
+
def flow_to_img_scflow(flow_uv, clip_flow=None):
    """
    Render a flow field as a color image using the SCFlow angle convention.
    Output channel order is RGB (convert_to_bgr is fixed to False here;
    callers convert with cv2 if needed).

    Args:
        flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
        clip_flow (float, optional): Clip maximum of flow values. Defaults to None.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
    if clip_flow is not None:
        flow_uv = np.clip(flow_uv, 0, clip_flow)
    u = flow_uv[:, :, 0]
    v = flow_uv[:, :, 1]
    # Normalize by the largest magnitude; epsilon avoids division by zero.
    scale = np.max(np.sqrt(np.square(u) + np.square(v))) + 1e-5
    return flow_uv_to_colors_scflow(u / scale, v / scale, convert_to_bgr=False)
|
215 |
+
|
216 |
+
|
217 |
+
#################### EVFlow Version ####################
|
218 |
+
"""
|
219 |
+
From https://github.com/chan8972/Spike-FlowNet/blob/master/vis_utils.py
|
220 |
+
"""
|
221 |
+
|
222 |
+
"""
|
223 |
+
Generates an RGB image where each point corresponds to flow in that direction from the center,
|
224 |
+
as visualized by flow_viz_tf.
|
225 |
+
Output: color_wheel_rgb: [1, width, height, 3]
|
226 |
+
"""
|
227 |
+
def draw_color_wheel_np(width, height):
    """
    Render the color-coding key image for flow_viz_np: each pixel is colored
    by the flow vector pointing from the image center to that pixel.

    Args:
        width (int): output width in pixels.
        height (int): output height in pixels.

    Returns:
        np.ndarray: color wheel image of shape [height, width, 3].
    """
    xs = np.linspace(-width / 2., width / 2., width)
    ys = np.linspace(-height / 2., height / 2., height)
    grid_x, grid_y = np.meshgrid(xs, ys)
    return flow_viz_np(grid_x, grid_y)
|
233 |
+
|
234 |
+
|
235 |
+
"""
|
236 |
+
Visualizes optical flow in HSV space using TensorFlow, with orientation as H, magnitude as V.
|
237 |
+
Returned as RGB.
|
238 |
+
Input: flow: [batch_size, width, height, 2]
|
239 |
+
Output: flow_rgb: [batch_size, width, height, 3]
|
240 |
+
"""
|
241 |
+
def flow_viz_np(flow_x, flow_y):
    """
    Visualize optical flow in HSV space: orientation as hue, magnitude
    (min-max normalized) as value, full saturation.

    From https://github.com/chan8972/Spike-FlowNet/blob/master/vis_utils.py

    Args:
        flow_x (np.ndarray): horizontal flow of shape [H,W]
        flow_y (np.ndarray): vertical flow of shape [H,W]

    Returns:
        np.ndarray: uint8 image of shape [H,W,3].
            NOTE(review): the upstream docstring says RGB, but the conversion
            uses cv2.COLOR_HSV2BGR, so the channel order is BGR — confirm
            against consumers before relying on it.
    """
    # (Removed a redundant function-local `import cv2`; the module already
    # imports cv2 at the top of the file.)
    flows = np.stack((flow_x, flow_y), axis=2)
    mag = np.linalg.norm(flows, axis=2)

    # Map angle from [-pi, pi] to [0, 180) to fit OpenCV's 8-bit hue range.
    ang = np.arctan2(flow_y, flow_x)
    ang += np.pi
    ang *= 180. / np.pi / 2.
    ang = ang.astype(np.uint8)
    hsv = np.zeros([flow_x.shape[0], flow_x.shape[1], 3], dtype=np.uint8)
    hsv[:, :, 0] = ang
    hsv[:, :, 1] = 255
    hsv[:, :, 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
    flow_rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
    return flow_rgb
|
256 |
+
|
257 |
+
|
258 |
+
|
259 |
+
#################### Visualization tools when training SCFlow ####################
|
260 |
+
def outflow_img(flow_list, vis_path, name_prefix='flow', max_batch=4):
    """
    Write per-sample SCFlow visualizations of the first flow tensor in
    `flow_list` to disk as PNGs.

    Args:
        flow_list (list[torch.Tensor]): list of flow tensors; only the first,
            of shape [B, 2, H, W], is visualized.
        vis_path (str): output directory.
        name_prefix (str): filename prefix. Defaults to 'flow'.
        max_batch (int): index cutoff — samples 0..max_batch inclusive are
            written (i.e. up to max_batch + 1 images), matching the original
            `batch > max_batch: break` loop.
    """
    flow = flow_list[0]
    batch_size, _, _, _ = flow.shape

    for sample_idx in range(min(batch_size, max_batch + 1)):
        # [2, H, W] -> [H, W, 2] numpy array on the CPU.
        sample = flow[sample_idx].permute(1, 2, 0).detach().cpu().numpy()
        vis = flow_visualization(sample, mode='scflow', use_cv2=True)

        cv2.imwrite(vis_path + '/{:s}_batch_id={:02d}.png'.format(name_prefix, sample_idx), vis)

    return
|