zzzzzeee commited on
Commit
9fa5305
·
verified ·
1 Parent(s): 487a673

Upload 28 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,15 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ snnTracker/0.dat filter=lfs diff=lfs merge=lfs -text
37
+ snnTracker/1.dat filter=lfs diff=lfs merge=lfs -text
38
+ snnTracker/2.dat filter=lfs diff=lfs merge=lfs -text
39
+ snnTracker/3.dat filter=lfs diff=lfs merge=lfs -text
40
+ snnTracker/4.dat filter=lfs diff=lfs merge=lfs -text
41
+ snnTracker/5.dat filter=lfs diff=lfs merge=lfs -text
42
+ snnTracker/6.dat filter=lfs diff=lfs merge=lfs -text
43
+ snnTracker/7.dat filter=lfs diff=lfs merge=lfs -text
44
+ snnTracker/8.dat filter=lfs diff=lfs merge=lfs -text
45
+ snnTracker/9.dat filter=lfs diff=lfs merge=lfs -text
46
+ snnTracker/driving_0_snntracker.avi filter=lfs diff=lfs merge=lfs -text
47
+ snnTracker/driving_0_tfi.avi filter=lfs diff=lfs merge=lfs -text
snnTracker/0.dat ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1d325c5ecb53e64a6429c753df5e0c3b46deaf3bce97f445499c9dfb64b3db4d
3
+ size 5000000
snnTracker/1.dat ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba866d8427b0baa97a0c09bab6055edceaac2dfff6fcba0580f21973628e5d1a
3
+ size 5000000
snnTracker/2.dat ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8a00b47dc8e9e84f902f159cd1f0ab387440091b470e4da9503e18e3bae4d354
3
+ size 5000000
snnTracker/3.dat ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:47a782e5eea659f9ddd8ad0e20b688e1e69ed9bf589f99989586397d9e2bd825
3
+ size 5000000
snnTracker/4.dat ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:02b795f9ca1f3affc59d59e280051bfd50f24c3d9c1c214ab8a680c9fdc2175a
3
+ size 5000000
snnTracker/5.dat ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:68ebdd547fc9c36401a3033701f725d9552de1ae235ea49f9bf04a1afad4a331
3
+ size 5000000
snnTracker/6.dat ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e5d0349fab1c1da0d3b3559c8bf30795f2488ec7caa182126525c5eb49f52d10
3
+ size 5000000
snnTracker/7.dat ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a36642ec9cc8ef292cb1223d8cf4c263dde4203da83f0e55179e3570a46754b3
3
+ size 5000000
snnTracker/8.dat ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cea9cb59395d668eab1c18b28cca3821e63f6ed2f097e5a120e67253ee78080a
3
+ size 5000000
snnTracker/9.dat ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:50ef6ee3758465fba69b5c453e76cae4a7c805805e646445e8877cc886849ce6
3
+ size 5000000
snnTracker/datasets/0/config.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # 基本配置
3
+ spike_h: 250 # 脉冲数据高度
4
+ spike_w: 400 # 脉冲数据宽度
5
+ is_labeled: false # 是否有标注数据
6
+
7
+ # 数据标识符
8
+ data_field_identifier: '' # 数据文件标识符
9
+ label_field_identifier: '' # 标注文件标识符
10
+
11
+ # 标注数据配置
12
+ labeled_data_type: 'tracking' # 标注数据类型
13
+ labeled_data_suffix: 'txt' # 标注文件后缀
snnTracker/driving_0_snntracker.avi ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f738eb42502b754cadf26539f749fe0281ba023cb139f32dff7af2a0a0cadf99
3
+ size 578258
snnTracker/driving_0_tfi.avi ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:62d6381f1f002c2f68b0e935d69db5ca2545ceb68566c6e6706f613cba6d2af4
3
+ size 9748482
snnTracker/path.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2023/7/16 20:19
3
+ # @Author : Yajing Zheng
4
+ # @Email: [email protected]
5
+ # @File : path.py
6
+ # here put the import lib
7
+ import os
8
+
9
+
10
def seek_file(search_dirs, filename):
    """Search *search_dirs*, then each of its ancestor directories, for *filename*.

    Walks every directory tree from the deepest candidate upward and returns
    the full path of the first match; returns None when nothing is found.
    """
    pieces = split_path_into_pieces(search_dirs)
    for level in range(len(pieces)):
        if level > 0:
            # climb one directory level before retrying the walk
            del pieces[-1]
        candidate_root = os.path.join(*pieces)
        for root, _dirs, files in os.walk(candidate_root):
            if filename in files:
                print('{0}/{1}'.format(root, filename))
                return os.path.join(root, filename)
24
+
25
+
26
def split_path_into_pieces(path: str):
    """Split a path string into its individual components.

    ``'/a/b/c'`` -> ``['/', 'a', 'b', 'c']``; ``'a/b/c'`` -> ``['a', 'b', 'c']``.

    Args:
        path: a POSIX- or OS-native path; a single trailing '/' is ignored.

    Returns:
        list of path components, root-first.  Returns [] for an empty path
        (BUGFIX: the original raised IndexError on ``path[-1]``).
    """
    if not path:
        return []
    # drop a single trailing separator so os.path.split terminates cleanly
    if path[-1] == '/':
        path = path[0:-1]

    pieces = []
    while True:
        head, tail = os.path.split(path)
        if head == '':
            # relative path fully consumed; 'tail' is the first component
            pieces.insert(0, tail)
            break
        if tail == '':
            # reached the filesystem root, e.g. '/'
            pieces.insert(0, head)
            break
        pieces.insert(0, tail)
        path = head

    return pieces
43
+
44
def replace_identifier(path: list, src: str, dst: str):
    """Return a copy of *path* with every component equal to *src* replaced by *dst*."""
    return [dst if piece == src else piece for piece in path]
snnTracker/spkData/load_dat.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2023/7/16 20:13
3
+ # @Author : Yajing Zheng
4
+ # @Email: [email protected]
5
+ # @File : load_dat.py
6
+ import os, sys
7
+ import warnings
8
+ import glob
9
+ import yaml
10
+ import numpy as np
11
+ import path
12
+
13
# key-value for generate data loader according to the type of label data
# (task name -> numeric code; 3.x distinguishes the two depth variants)
LABEL_DATA_TYPE = {
    'raw': 0,
    'reconstruction': 1,
    'optical_flow': 2,
    'mono_depth_estimation': 3.1,
    'stero_depth_estimation': 3.2,  # NOTE(review): 'stero' typo kept -- the key is runtime data
    'detection': 4,
    'tracking': 5,
    'recognition': 6
}
24
+
25
+
26
+ # generate parameters dictionary according to labeled or not
27
# generate parameters dictionary according to labeled or not
def data_parameter_dict(data_filename, label_type):
    """Assemble the loader parameter dict for a spike ``.dat`` recording.

    Locates the dataset's ``config.yaml`` (next to an absolute path, or under
    ``datasets/<name>/`` for a relative one), reads the spatial resolution and
    labeling options, and resolves the data / label file paths.

    Args:
        data_filename: absolute path, or path relative to ``datasets/``, of a
            ``.dat`` file or of a directory of ``.dat`` files.
        label_type: task name; must be a key of ``LABEL_DATA_TYPE``.

    Returns:
        dict with ``spike_h``, ``spike_w``, ``filepath``, ``filelist`` and,
        when the dataset is labeled, ``labeled_data_type``,
        ``labeled_data_suffix``, ``labeled_data_dir``, ``label_root_list``.

    Raises:
        KeyError: if *label_type* is not a known task name.
        TypeError: if no ``config.yaml`` could be located.
    """
    filename = path.split_path_into_pieces(data_filename)

    if os.path.isabs(data_filename):
        file_root = data_filename
        if os.path.isdir(file_root):
            search_root = file_root
        else:
            # BUGFIX: was '\\'.join(...), which hard-coded the Windows
            # separator; rebuild the parent directory portably instead.
            search_root = os.path.join(*filename[0:-1]) if len(filename) > 1 else ''
        config_filename = path.seek_file(search_root, 'config.yaml')
    else:
        file_root = os.path.join('', 'datasets', *filename)
        config_filename = os.path.join('', 'datasets', filename[0], 'config.yaml')

    try:
        # BUGFIX: the KeyError handler below was unreachable because
        # label_type was never looked up -- validate the task name here.
        _ = LABEL_DATA_TYPE[label_type]
        with open(config_filename, 'r', encoding='utf-8') as fin:
            configs = yaml.load(fin, Loader=yaml.FullLoader)
    except TypeError as err:
        # path.seek_file returns None when nothing is found; open(None)
        # raises TypeError.
        print("Cannot find config file" + str(err))
        raise err

    except KeyError as exception:
        print('ERROR! Task name does not exist')
        print('Task name must be in %s' % LABEL_DATA_TYPE.keys())
        raise exception

    is_labeled = configs.get('is_labeled')

    paraDict = {'spike_h': configs.get('spike_h'), 'spike_w': configs.get('spike_w')}
    paraDict['filelist'] = None

    if is_labeled:
        paraDict['labeled_data_type'] = configs.get('labeled_data_type')
        paraDict['labeled_data_suffix'] = configs.get('labeled_data_suffix')
        paraDict['label_root_list'] = None

        if os.path.isdir(file_root):
            # directory of .dat files: keep them in modification-time order
            filelist = sorted(glob.glob(file_root + '/*.dat'), key=os.path.getmtime)
            filepath = filelist[0]

            labelname = path.replace_identifier(filename, configs.get('data_field_identifier', ''),
                                                configs.get('label_field_identifier', ''))
            label_root_list = os.path.join('', 'datasets', *labelname)
            paraDict['labeled_data_dir'] = sorted(glob.glob(label_root_list + '/*.' + paraDict['labeled_data_suffix']),
                                                  key=os.path.getmtime)

            paraDict['filelist'] = filelist
            paraDict['label_root_list'] = label_root_list
        else:
            filepath = glob.glob(file_root)[0]
            # swap the data identifier for the label identifier in the name
            rawname = filename[-1].replace('.dat', '')
            filename.pop(-1)
            filename.append(rawname)
            labelname = path.replace_identifier(filename, configs.get('data_field_identifier', ''),
                                                configs.get('label_field_identifier', ''))
            label_root = os.path.join('', 'datasets', *labelname)
            paraDict['labeled_data_dir'] = glob.glob(label_root + '.' + paraDict['labeled_data_suffix'])[0]
    else:
        filepath = file_root

    paraDict['filepath'] = filepath

    return paraDict
90
+
91
+
92
class SpikeStream:
    """Reader for spike-camera ``.dat`` streams.

    The file stores binary spike frames packed 8 pixels per byte, least
    significant bit first.  Frames are ``spike_h`` rows by ``spike_w``
    columns (416 columns when the stream carries a 16-column head).

    Keyword Args:
        filepath: path of the ``.dat`` file ('.dat' is appended if missing).
        spike_w: frame width in pixels.
        spike_h: frame height in pixels.
        print_dat_detail: print decoding details (default True).
    """

    def __init__(self, **kwargs):
        self.SpikeMatrix = None  # last decoded (T, H, W) 0/1 matrix
        self.filename = kwargs.get('filepath')
        if os.path.splitext(self.filename)[-1][1:] != 'dat':
            self.filename = self.filename + '.dat'
        self.spike_w = kwargs.get('spike_w')
        self.spike_h = kwargs.get('spike_h')
        self.print_dat_detail = kwargs.get('print_dat_detail', True)

    def _decode(self, byte_seq, frame_num, decode_width, flipud, with_head):
        """Unpack *byte_seq* into a (frame_num, spike_h, W) 0/1 np.byte matrix.

        Bits are read LSB-first within each byte.  When *with_head* is set,
        the 16 trailing header columns (400..415) are dropped; *flipud*
        reverses row order (presumably the sensor scans bottom-up -- TODO
        confirm against the camera spec).
        """
        pix_id = np.arange(frame_num * self.spike_h * decode_width)
        pix_id = np.reshape(pix_id, (frame_num, self.spike_h, decode_width))
        comparator = np.left_shift(1, np.mod(pix_id, 8))  # per-pixel bit mask
        byte_id = pix_id // 8
        result = np.bitwise_and(byte_seq[byte_id], comparator)
        matrix = (result == comparator)
        if with_head:
            # columns 400..415 carry head bits, not pixels
            matrix = np.delete(matrix, np.arange(400, 416), 2)
        if flipud:
            matrix = matrix[:, ::-1, :]
        return matrix.astype(np.byte)

    def get_spike_matrix(self, flipud=True, with_head=False):
        """Decode the whole file; returns a (T, spike_h, spike_w) 0/1 matrix."""
        # BUGFIX: context manager so the file is closed even on error
        # (the original leaked the handle if decoding raised).
        with open(self.filename, 'rb') as file_reader:
            video_seq = np.frombuffer(file_reader.read(), 'b').astype(np.byte)

        if self.print_dat_detail:
            print(video_seq)
        decode_width = 416 if with_head else self.spike_w
        img_size = self.spike_h * decode_width
        img_num = len(video_seq) // (img_size // 8)

        if self.print_dat_detail:
            print('loading total spikes from dat file -- spatial resolution: %d x %d, total timestamp: %d' %
                  (decode_width, self.spike_h, img_num))

        self.SpikeMatrix = self._decode(video_seq, img_num, decode_width, flipud, with_head)
        return self.SpikeMatrix

    # return spikes with specified length and begin index
    def get_block_spikes(self, begin_idx, block_len=1, flipud=True, with_head=False):
        """Decode *block_len* frames starting at frame *begin_idx*.

        Frames past the end of the file are zero-padded.  (BUGFIX: the
        original warned about padding but then fancy-indexed past the end of
        the truncated buffer and raised IndexError.)
        """
        with open(self.filename, 'rb') as file_reader:
            video_seq = np.frombuffer(file_reader.read(), 'b').astype(np.uint8)

        decode_width = 416 if with_head else self.spike_w
        img_size = self.spike_h * decode_width
        img_num = len(video_seq) // (img_size // 8)

        if begin_idx + block_len > img_num:
            warnings.warn("block_len exceeding upper limit! Zeros will be padded in the end. ", ResourceWarning)

        if self.print_dat_detail:
            print(
                'loading total spikes from dat file -- spatial resolution: %d x %d, begin index: %d total timestamp: %d' %
                (decode_width, self.spike_h, begin_idx, block_len))

        bytes_per_frame = img_size // 8
        id_start = begin_idx * bytes_per_frame
        data = video_seq[id_start:id_start + block_len * bytes_per_frame]
        expected = block_len * bytes_per_frame
        if len(data) < expected:
            # actually pad with zeros as the warning promises
            data = np.concatenate([data, np.zeros(expected - len(data), dtype=data.dtype)])

        self.SpikeMatrix = self._decode(data, block_len, decode_width, flipud, with_head)
        return self.SpikeMatrix
snnTracker/spkProc/detection/attention_select.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import skimage.morphology as smor
2
+ from skimage.feature import peak_local_max
3
+ from skimage.morphology import erosion
4
+ from skimage.measure import label, regionprops
5
+ import numpy as np
6
+ import torch
7
+ from torchvision.transforms import Resize
8
+
9
+
10
+ # obtain 2D gaussian filter
11
# obtain 2D gaussian filter
def get_kernel(filter_size, sigma):
    """Return a 2-D Gaussian kernel, min-max normalized to [0, 1].

    Args:
        filter_size: odd kernel height/width.
        sigma: Gaussian standard deviation.

    Returns:
        (filter_size, filter_size) float32 array; 1.0 at the center, 0.0 at
        the corners.
    """
    assert (filter_size + 1) % 2 == 0, '2D filter size must be odd number!'
    half_width = (filter_size - 1) // 2
    # offsets from the kernel center along each axis (vectorized version of
    # the original double loop; symmetric in i/j so orientation is identical)
    jj, ii = np.meshgrid(np.arange(-half_width, half_width + 1),
                         np.arange(-half_width, half_width + 1))
    g = np.exp(-(ii ** 2 + jj ** 2) / (2.0 * sigma * sigma)).astype(np.float32)

    span = g.max() - g.min()
    if span == 0:
        # 1x1 (degenerate) kernel: the original produced NaN via 0/0;
        # return the unit peak instead
        return np.ones_like(g)
    g = (g - g.min()) / span
    return g
28
+
29
+
30
+ # detect moving connected regions
31
# detect moving connected regions
class SaccadeInput:
    """Attention-region selector over a spike plane.

    Maintains a 2-D membrane-potential field ``U`` updated from incoming
    spike frames via a fixed Gaussian convolution plus divisively
    normalized recurrent excitation (a dynamic-neural-field style update --
    NOTE(review): interpretation inferred from the math; confirm with the
    authors).  Super-threshold blobs of ``U`` become attention boxes whose
    spike content is resampled to a fixed patch size.
    """

    def __init__(self, spike_h, spike_w, box_size, device, attentionThr=None, extend_edge=None):
        # spike_h / spike_w: spatial resolution of the spike plane
        # box_size: half-width of the attention box (kernel spans 2*box_size+1)
        # device: torch device holding the field and the convolution
        # attentionThr: firing threshold on U (default 40)
        # extend_edge: margin added around each detected region (default 7)

        self.spike_h = spike_h
        self.spike_w = spike_w
        self.device = device

        # DNF membrane potential over the full spike plane
        self.U = torch.zeros(self.spike_h, self.spike_w, dtype=torch.float32)
        self.tau_u = 0.5            # integration step for U updates
        self.global_inih = 0.01     # global divisive inhibition strength
        self.box_width = box_size  # attention box width
        self.Jxx_size = self.box_width * 2 + 1
        # lateral-interaction kernel implemented as a fixed (non-trained) conv
        self.Jxx = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(self.Jxx_size, self.Jxx_size),
                                   padding=(self.box_width, self.box_width), bias=False)

        tmp_filter = get_kernel(self.Jxx_size, round(self.box_width / 2) + 1)
        tmp_filter = tmp_filter.reshape((1, 1, self.Jxx_size, self.Jxx_size))
        self.Jxx.weight.data = torch.from_numpy(tmp_filter)
        # resamples each attention patch to the fixed kernel size
        self.resizer = Resize((self.Jxx_size, self.Jxx_size))

        self.U = self.U.to(self.device)
        self.Jxx = self.Jxx.to(self.device)

        if attentionThr is not None:
            self.attentionThr = attentionThr
        else:
            self.attentionThr = 40
        if extend_edge is not None:
            self.extend_edge = extend_edge
        else:
            self.extend_edge = 7
        # self.extend_edge = 1
        # structuring-element size used by the erosion in get_attention_location
        self.peak_width = int(self.extend_edge)

    def update_dnf(self, spike):
        """Integrate one binary spike frame into the field ``U`` in place."""
        inputSpk = torch.reshape(spike, (1, 1, self.spike_h, self.spike_w)).float()

        maxU = torch.relu(self.U)
        squareU = torch.square(maxU)
        # divisively normalized firing rate of the field
        r = squareU / (1 + self.global_inih * torch.sum(squareU))
        conv_fired = self.Jxx(inputSpk)
        conv_fired = torch.squeeze(conv_fired).to(self.device)
        # feed-forward drive minus leak
        du = conv_fired - self.U

        r = torch.reshape(r, (1, 1, self.spike_h, self.spike_w))
        conv_r = self.Jxx(r)
        conv_r = torch.squeeze(conv_r).to(self.device)
        # add recurrent excitation, then integrate with step tau_u
        du = conv_r + du
        self.U += (du * self.tau_u).detach()

        del inputSpk, maxU, squareU, r, conv_r, conv_fired, du

    def get_attention_location(self, spikes):
        """Extract attention boxes and encoded patches from the current field.

        Args:
            spikes: (spike_h, spike_w) tensor, the current spike frame.

        Returns:
            attentionBox: (N, 4) int tensor of [beginX, beginY, endX, endY].
            attentionInput: (Jxx_size + 4, Jxx_size, N) tensor; rows above
                the last 4 are the re-binarized resized patch, the last 4
                rows encode the (1-based) box corners as binary digit rows.
        """
        tmpU = torch.relu(self.U - self.attentionThr)
        tmpU = tmpU.cpu()
        tmpU = tmpU.detach().numpy()
        # NOTE(review): despite the variable name, this is an erosion -- it
        # shrinks super-threshold blobs before peak detection
        dilated_u = erosion(tmpU, smor.square(self.peak_width))
        peak_cord = peak_local_max(dilated_u, min_distance=self.box_width)
        num_max = len(peak_cord)  # currently unused except for the debug print below
        # print('detect %d attention location' % num_max)
        # binarize, then label connected components as candidate regions
        dilated_u[dilated_u > 1] = 1
        dilated_u[dilated_u < 1] = 0
        region_labels = label(dilated_u)
        regions = regionprops(region_labels)
        num_box = len(regions)

        attentionBox = torch.zeros((num_box, 4), dtype=torch.int)
        # 4 extra rows per patch carry the box-coordinate encoding
        attentionInput = torch.zeros(self.Jxx_size + 4, self.Jxx_size, num_box)

        for region, iBox in zip(regions, range(num_box)):
            minr, minc, maxr, maxc = region.bbox
            # expand the bbox by extend_edge, clamped to the spike plane
            # (the `cond and a or b` form is an old-style ternary)
            beginX = minr - self.extend_edge >= 0 and minr - self.extend_edge or 0
            beginY = minc - self.extend_edge >= 0 and minc - self.extend_edge or 0
            endX = maxr + self.extend_edge < self.spike_h and maxr + self.extend_edge or self.spike_h - 1
            endY = maxc + self.extend_edge < self.spike_w and maxc + self.extend_edge or self.spike_w - 1

            attentionBox[iBox, :] = torch.tensor([beginX, beginY, endX, endY])
            attentionI = torch.unsqueeze(spikes[beginX:endX + 1, beginY:endY + 1], dim=0)
            attentionI = self.resizer.forward(attentionI)
            # re-binarize after the interpolation performed by the resizer
            fire_index = torch.where(attentionI > 0.9)
            attentionI2 = torch.zeros_like(attentionI)
            attentionI2[0, fire_index[1], fire_index[2]] = 1
            attentionInput[:-4, :, iBox] = torch.squeeze(attentionI2).detach().clone()
            # encode each (1-based) corner coordinate as binary digits,
            # tiled twice across the row
            tmp_spk = bin(beginX + 1)
            tmp_spk = tmp_spk[2:].zfill(self.box_width)
            attentionInput[-4, :-1, iBox] = torch.from_numpy(np.tile(np.array(list(tmp_spk), dtype=np.float32), (1, 2)))
            tmp_spk = bin(beginY + 1)
            tmp_spk = tmp_spk[2:].zfill(self.box_width)
            attentionInput[-3, :-1, iBox] = torch.from_numpy(np.tile(np.array(list(tmp_spk), dtype=np.float32), (1, 2)))
            tmp_spk = bin(endX + 1)
            tmp_spk = tmp_spk[2:].zfill(self.box_width)
            attentionInput[-2, :-1, iBox] = torch.from_numpy(np.tile(np.array(list(tmp_spk), dtype=np.float32), (1, 2)))
            tmp_spk = bin(endY + 1)
            tmp_spk = tmp_spk[2:].zfill(self.box_width)
            attentionInput[-1, :-1, iBox] = torch.from_numpy(np.tile(np.array(list(tmp_spk), dtype=np.float32), (1, 2)))

        attentionInput = attentionInput.to(self.device)
        del tmpU, dilated_u, peak_cord

        return attentionBox, attentionInput
snnTracker/spkProc/detection/motion_clustering.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2024/12/6 3:34
3
+ # @Author : Yajing Zheng
4
+ # @Email: [email protected]
5
+ # @File : motion_clustering.py
6
+ from sklearn.cluster import DBSCAN, OPTICS, SpectralClustering
7
+ from sklearn.metrics import pairwise_distances
8
+ from sklearn.preprocessing import StandardScaler
9
+ import numpy as np
10
+ import scipy.ndimage.measurements as mnts
11
+ import torch
12
+
13
+
14
class detect_object:
    """Cluster moving pixels into objects.

    Pixels with non-zero motion are clustered with DBSCAN on a precomputed
    distance that averages (standardized) spatial distance and motion-vector
    distance; connected components of the resulting cluster image then yield
    per-object bounding boxes.
    """

    def __init__(self, h, w):
        # h / w: spatial resolution of the motion field
        self.h = h
        self.w = w
        params = {'quantile': .3,
                  'eps': .4,
                  'damping': .9,
                  'preference': -200,
                  'n_neighbors': 10,
                  'min_samples': 50,
                  'xi': 0.05,
                  'min_cluster_size': 0.1,
                  'n_cluster': 2}
        # all three clusterers take a precomputed distance/affinity matrix;
        # only DBSCAN is used in get_object, the others are kept for experiments
        self.dbscan = DBSCAN(eps=params['eps'], min_samples=params['min_samples'], metric='precomputed')
        self.optics = OPTICS(min_samples=params['min_samples'], xi=params['xi'],
                             min_cluster_size=params['min_cluster_size'], metric='precomputed')
        self.spectral = SpectralClustering(n_clusters=params['n_cluster'], eigen_solver='arpack',
                                           affinity='precomputed')

    def get_object(self, motion_vector, max_motion=None):
        """Cluster the (H, W, 2) motion field.

        Args:
            motion_vector: (H, W, 2) torch tensor of per-pixel motion.
            max_motion: optional (H, W) torch tensor; when given, pixels with
                a non-zero entry are clustered instead of non-zero motion.

        Returns:
            (labels, fire_idx): cluster id per pixel (-1 = DBSCAN noise) and
            the (N, 2) pixel coordinates, or (None, None) if nothing moves.
        """
        if max_motion is None:
            mv_idx = torch.where(torch.logical_or(motion_vector[:, :, 0] != 0, motion_vector[:, :, 1] != 0))
            if len(mv_idx[0]) < 1:
                return None, None
            # BUGFIX: np.int was removed from NumPy (>=1.24); use builtin int
            fire_idx = np.zeros((2, len(mv_idx[0])), dtype=int)
            fire_idx[0, :] = mv_idx[0].cpu().numpy()
            fire_idx[1, :] = mv_idx[1].cpu().numpy()
        else:
            max_motion = max_motion.cpu().numpy()
            if max_motion.max() < 1:
                return None, None
            fire_idx = np.array(np.nonzero(max_motion))

        motion_vector = motion_vector.cpu().numpy()

        fire_idx = fire_idx.T
        num_events = len(fire_idx)
        fire_idx_ori = fire_idx
        spatial_vector = StandardScaler().fit_transform(fire_idx)

        motion_array = np.zeros((num_events, 2))
        motion_array[:, 0] = motion_vector[fire_idx[:, 0], fire_idx[:, 1], 0]
        motion_array[:, 1] = motion_vector[fire_idx[:, 0], fire_idx[:, 1], 1]
        motion_array = StandardScaler().fit_transform(motion_array)

        # combined metric: equal-weight average of motion-space and
        # image-space euclidean distances
        motion_dis = pairwise_distances(motion_array, metric='euclidean')
        spatial_dis = pairwise_distances(spatial_vector, metric='euclidean')
        total_dis = 0.5 * (motion_dis + spatial_dis)

        self.dbscan.fit(total_dis)
        labels = self.dbscan.labels_.astype(int)  # BUGFIX: np.int removed

        return labels, fire_idx_ori

    def detection_object_with_motion(self, fireID, clusterId):
        """Turn per-pixel cluster ids into (N, 4) [x0, y0, x1, y1] boxes.

        Args:
            fireID: (N, 2) integer array of pixel coordinates.
            clusterId: (N,) integer array of cluster ids (0-based).

        Returns:
            (num_clusters, 4) float array of half-open bounding boxes.
        """
        # local import: scipy.ndimage.measurements was deprecated and later
        # removed; find_objects now lives in scipy.ndimage directly
        from scipy.ndimage import find_objects

        L = np.zeros((self.h, self.w), dtype=int)  # BUGFIX: np.int removed
        L[fireID[:, 0], fireID[:, 1]] = clusterId + 1  # 0 stays background

        bboxSlices = find_objects(L)
        box_num = clusterId.max() + 1
        bbox = np.zeros((box_num, 4))
        for iBox in range(box_num):
            row_slc, col_slc = bboxSlices[iBox]
            bbox[iBox, :] = [row_slc.start, col_slc.start, row_slc.stop, col_slc.stop]

        return bbox
101
+
snnTracker/spkProc/detection/stdp_clustering.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from config import *
2
+ import numpy as np
3
+ import torch
4
+ from collections import namedtuple
5
+
6
+ detect_box = namedtuple('detect_box', ['zId', 'box', 'velocity'])
7
+ tracks = namedtuple('tracks', ['id', 'color', 'bbox', 'predbox', 'visible', 'vel', 'age', 'unvisible_count'])
8
+ trajectories = namedtuple('trajectories', ['id', 'x', 'y', 't', 'color'])
9
+ tracks_bbox = namedtuple('tracks_bbox', ['id', 't', 'x', 'y', 'h', 'w'])
10
+
11
+
12
+ class stdp_cluster():
13
+
14
+ def __init__(self, spike_h, spike_w, box_size, device):
15
+
16
+ self.spike_h = spike_h
17
+ self.spike_w = spike_w
18
+ self.box_size = box_size
19
+ self.K1 = 1
20
+ self.K2 = 5
21
+ # self.InputSize = (2 * box_size + 1)**2
22
+ self.InputSize = (2 * box_size + 1) * (2 * box_size + 5)
23
+ self.device = device
24
+ # self.InputSize = box_size * (box_size + 4)
25
+ # self.InputSize = box_size**2
26
+
27
+ # self.synaptic_weight = torch.ones(self.K1, self.InputSize, dtype=torch.float64) / self.K1
28
+ self.synaptic_weight = torch.rand(self.K2, self.K1, self.InputSize, dtype=torch.float64)
29
+ # self.synaptic_weight = self.synaptic_weight / torch.sum(self.synaptic_weight)
30
+ # self.synaptic_weight = torch.unsqueeze(self.synaptic_weight, dim=0)
31
+ # self.synaptic_weight = self.synaptic_weight.repeat(self.K2, 1, 1)
32
+ self.bias_weight = torch.ones(self.K2, 1, dtype=torch.float64) / self.K2
33
+
34
+ self.synaptic_weight = self.normalization_w(self.synaptic_weight)
35
+ self.bias_weight = self.normalization_w(self.bias_weight)
36
+ self.synaptic_weight = self.synaptic_weight.to(device)
37
+ self.bias_weight = self.bias_weight.to(device)
38
+
39
+ self.learning_rate = 0.001
40
+ self.iter_num = 5
41
+ self.stdp_coefficience = 1
42
+
43
+ self.w_up = 1
44
+ self.w_low = -8
45
+
46
+ self.w_up_tensor = torch.ones_like(self.synaptic_weight, dtype=torch.float64) * self.w_up
47
+ self.w_low_tensor = torch.ones_like(self.synaptic_weight, dtype=torch.float64) * self.w_low
48
+
49
+ self.tracks = []
50
+ self.trajectories = []
51
+ self.tracks_bbox = []
52
+ # self.seed_everything(5)
53
+
54
+ self.background_occ_fr = 10 # background oscillation rate 20 Hz
55
+ self.occ_fr = torch.Tensor([self.background_occ_fr / 20000.0])
56
+
57
+ for i_neuron in range(self.K2):
58
+ self.tracks.append(tracks(i_neuron, 255 * np.random.rand(1, 3), torch.zeros((1, 4), dtype=torch.float64),
59
+ torch.zeros((1, 4), dtype=torch.float64), 0,
60
+ torch.zeros((2,)), 0, 0))
61
+ self.trajectories.append(trajectories(i_neuron, [], [], [], self.tracks[i_neuron].color))
62
+ self.tracks_bbox.append(tracks_bbox(i_neuron, [], [], [], [], []))
63
+
64
+
65
+ def normalization_w(self, weight):
66
+
67
+ if len(weight.shape) == 2:
68
+ exp_w = torch.exp(weight)
69
+ if torch.sum(exp_w) == 0:
70
+ norm_w = weight.clone()
71
+ else:
72
+ exp_w = exp_w / torch.sum(exp_w)
73
+ norm_w = torch.log(exp_w)
74
+
75
+ del exp_w
76
+ else:
77
+ exp_w = torch.exp(weight)
78
+ w_norm = torch.sum(exp_w, dim=2)
79
+ w_norm = torch.unsqueeze(w_norm, dim=2)
80
+ w_norm = w_norm.tile(1, 1, self.InputSize)
81
+ norm_w = exp_w.detach().clone()
82
+
83
+ valid_index = torch.where(w_norm != 0)
84
+ norm_w[valid_index] = exp_w[valid_index] / w_norm[valid_index]
85
+ norm_w = torch.log(norm_w)
86
+
87
+ del exp_w, w_norm, valid_index
88
+
89
+ torch.cuda.empty_cache()
90
+ return norm_w
91
+
92
+ # winner-take-all
93
+ # @staticmethod
94
+ def wta(self, attention_spikes, synaptic_weight, bias_weight):
95
+
96
+ attention_spikes = attention_spikes.double()
97
+ synaptic_weight = torch.squeeze(synaptic_weight + abs(self.w_low))
98
+ # synaptic_weight = torch.squeeze(synaptic_weight)
99
+ intPre = torch.matmul(synaptic_weight, attention_spikes)
100
+ intPre = torch.squeeze(intPre)
101
+ intPre[intPre > 700] = 700
102
+ psp = intPre + torch.squeeze(bias_weight)
103
+
104
+ rate_norm = (torch.Tensor([1]) * 1.0).to(self.device)
105
+ sum_exp = torch.sum(torch.exp(psp))
106
+ fire_inhb = torch.log(sum_exp) - torch.log(rate_norm)
107
+
108
+ tmp_psp = torch.exp(psp - fire_inhb)
109
+ tmp_psp[torch.isinf(tmp_psp)] = 0
110
+ fire_index = torch.where(torch.rand(1).to(self.device) < tmp_psp)
111
+ Z_spike = torch.zeros(tmp_psp.shape)
112
+ Z_spike[fire_index] = 1
113
+
114
+ del intPre, psp, rate_norm, sum_exp, fire_inhb, tmp_psp, fire_index
115
+ torch.cuda.empty_cache()
116
+ return Z_spike
117
+
118
+ @staticmethod
119
+ def intersect(box_a, box_b):
120
+ """ We resize both tensors to [A,B,2] without new malloc:
121
+ [A,2] -> [A,1,2] -> [A,B,2]
122
+ [B,2] -> [1,B,2] -> [A,B,2]
123
+ Then we compute the area of intersect between box_a and box_b.
124
+ Args:
125
+ box_a: (tensor) bounding boxes, Shape: [A,4].
126
+ box_b: (tensor) bounding boxes, Shape: [B,4].
127
+ Return:
128
+ (tensor) intersection area, Shape: [A,B].
129
+ """
130
+ A = box_a.size(0)
131
+ B = box_b.size(0)
132
+ max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2),
133
+ box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
134
+ min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2),
135
+ box_b[:, :2].unsqueeze(0).expand(A, B, 2))
136
+ inter = torch.clamp((max_xy - min_xy), min=0)
137
+ return inter[:, :, 0] * inter[:, :, 1]
138
+
139
+ def jaccard(self, box_a, box_b):
140
+ """Compute the jaccard overlap of two sets of boxes. The jaccard overlap
141
+ is simply the intersection over union of two boxes. Here we operate on
142
+ ground truth boxes and default boxes.
143
+ E.g.:
144
+ A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
145
+ Args:
146
+ box_a: (tensor) Ground truth bounding boxes, Shape: [A,4]
147
+ box_b: (tensor) Prior boxes from priorbox layers, Shape: [B,4]
148
+ Return:
149
+ jaccard overlap: (tensor) Shape: [A, B]
150
+ """
151
+ inter = self.intersect(box_a, box_b)
152
+ area_a = ((box_a[:, 2] - box_a[:, 0]) *
153
+ (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B]
154
+ area_b = ((box_b[:, 2] - box_b[:, 0]) *
155
+ (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B]
156
+ union = area_a + area_b - inter
157
+ return inter / union # [A,B]
158
+
159
+ def update_weight(self, attention_input):
160
+
161
+ n_attention = attention_input.shape[2]
162
+ predict_fire = torch.zeros(n_attention, self.K2)
163
+ synaptic_weight = self.synaptic_weight.detach().clone()
164
+ bias_weight = self.bias_weight.detach().clone()
165
+ lr_weight = torch.zeros(self.K2).to(self.device)
166
+ has_fired = np.zeros((self.K2, 1))
167
+
168
+ for iPattern in range(n_attention):
169
+ detected = -1
170
+ input_spike = torch.reshape(attention_input[:, :, iPattern], (-1, 1))
171
+ # background_noise = (torch.rand(input_spike.shape) < self.occ_fr).to(device)
172
+ # input_spike = (torch.logical_or((input_spike).type(torch.bool), background_noise)).type(torch.float32)
173
+ # input_spike = torch.reshape(attention_input, (-1, n_attention))
174
+ confusion_flag = 0
175
+ for i in range(self.iter_num):
176
+ z_spike = self.wta(input_spike, synaptic_weight, bias_weight).to(self.device)
177
+ dw_bias = self.learning_rate * (z_spike * torch.squeeze(torch.exp(-bias_weight)) - 1)
178
+ # tmp_sum = torch.sum(dw_bias, dim=1)
179
+ bias_weight += torch.unsqueeze(dw_bias, dim=1).detach()
180
+
181
+ for iZ in range(self.K2):
182
+ if z_spike[iZ] != 0 and has_fired[iZ] == 0 and (detected == -1 or iZ == detected):
183
+ has_fired[iZ] = 1
184
+ detected = iZ
185
+ # fire_idx = torch.where(z_spike[iZ, :]!=0)
186
+ tmpE = torch.exp(-synaptic_weight[iZ, :, :])
187
+ dw = self.stdp_coefficience * tmpE * torch.transpose(input_spike, 0, 1) - 1
188
+ lr_weight[iZ] += 1
189
+
190
+ synaptic_weight[iZ, :, :] += ((1.0 / lr_weight[iZ]) * dw.to(self.device))
191
+ # synaptic_weight[iZ, :, :] += self.learning_rate * dw.to(device)
192
+
193
+ synaptic_weight = torch.where(synaptic_weight < self.w_up, synaptic_weight,
194
+ self.w_up_tensor).detach()
195
+ synaptic_weight = torch.where(synaptic_weight < self.w_low, self.w_low_tensor,
196
+ synaptic_weight).detach()
197
+
198
+ synaptic_weight = self.normalization_w(synaptic_weight)
199
+ bias_weight = self.normalization_w(bias_weight)
200
+ predict_fire[iPattern, iZ] = torch.Tensor([1])
201
+
202
+ # predict_fire[iPattern, :] = z_spike.detach()
203
+ # predict_fire = z_spike.detach()
204
+ # print(synaptic_weight.max())
205
+ # print(synaptic_weight.min())
206
+
207
+ del n_attention, lr_weight
208
+ torch.cuda.empty_cache()
209
+ return predict_fire, synaptic_weight, bias_weight
210
+
211
+ def seed_everything(self, seed=11):
212
+ np.random.seed(seed)
213
+ torch.manual_seed(seed)
214
+ torch.cuda.manual_seed(seed)
215
+ torch.cuda.manual_seed_all(seed)
216
+ torch.backends.cudnn.deterministic = True
217
+ torch.backends.cudnn.benchmark = False
218
+
219
+ return
220
+
221
+ @staticmethod
222
+ def detect_object(predict_fire, attention_box, motion_id, motion_vector, **kwargs):
223
+
224
+ spike_h = kwargs.get('spike_h')
225
+ spike_w = kwargs.get('spike_w')
226
+ device = kwargs.get('device')
227
+ nAttention = attention_box.shape[0]
228
+ boxId = torch.zeros(nAttention, 1)
229
+ predBox = torch.zeros((nAttention, 4), dtype=torch.int)
230
+ velocities = torch.zeros(nAttention, 2).to(device)
231
+ predict_box = []
232
+
233
+ for iPattern in range(nAttention):
234
+ z_spike = predict_fire[iPattern, :]
235
+ if torch.any(z_spike != 0):
236
+ if len(torch.where(z_spike != 0)[0]) > 1:
237
+ print('check')
238
+
239
+ tmp_fired = torch.where(z_spike != 0)[0]
240
+ boxId[iPattern] = tmp_fired[0] + 1
241
+
242
+ x = attention_box[iPattern, 0]
243
+ y = attention_box[iPattern, 1]
244
+ end_x = attention_box[iPattern, 2]
245
+ end_y = attention_box[iPattern, 3]
246
+
247
+ tmp_motion = torch.zeros(spike_h, spike_w)
248
+ tmp_motion[x:end_x + 1, y:end_y + 1] = motion_id[x:end_x + 1, y:end_y + 1].clone()
249
+
250
+ motion_index2d = torch.where(tmp_motion != 0)
251
+ if len(motion_index2d[0]) == 0:
252
+ continue
253
+
254
+ motion_num = len(motion_index2d[0])
255
+ block_veloctiy = torch.zeros(motion_num, 2).to(device)
256
+ block_veloctiy[:, 0] = motion_vector[motion_index2d[0], motion_index2d[1], 0].clone()
257
+ block_veloctiy[:, 1] = motion_vector[motion_index2d[0], motion_index2d[1], 1].clone()
258
+ tmp_veloctiy = torch.mean(block_veloctiy, dim=0)
259
+ velocities[iPattern, :] = tmp_veloctiy.data
260
+ predBox[iPattern, :] = attention_box[iPattern, :]
261
+
262
+ predict_box.append(detect_box(boxId[iPattern],
263
+ torch.unsqueeze(predBox[iPattern], dim=0),
264
+ velocities[iPattern]))
265
+ # else:
266
+ # print('no tracking neuron fire..')
267
+
268
+ del boxId, predBox, velocities
269
+ torch.cuda.empty_cache()
270
+
271
+ return predict_box
272
+
273
+ def update_tracks(self, detect_objects, sw, bw, timestep):
274
+
275
+ objects_num = len(detect_objects)
276
+ id_check = torch.zeros(self.K2, 1)
277
+ AssignTrk = []
278
+
279
+ for iObject in range(objects_num):
280
+ tmp_object = detect_objects[iObject]
281
+ id = int(tmp_object.zId.detach().item())
282
+ box = tmp_object.box
283
+ velocity = tmp_object.velocity
284
+
285
+ if id_check[id - 1] != 0:
286
+ if id in AssignTrk:
287
+ # print('id %d repeat' % (id-1))
288
+ AssignTrk.remove(id)
289
+ continue
290
+ else:
291
+ id_check[id - 1] = 1
292
+
293
+ pred_box = self.tracks[id - 1].predbox
294
+ boxes_iou = self.jaccard(box, pred_box)
295
+ unvisible_count = self.tracks[id - 1].unvisible_count
296
+ if ~(self.tracks[id - 1].predbox[0, 3] != 0 and self.tracks[id - 1].age > 15
297
+ and boxes_iou < 0.6):
298
+ self.tracks[id - 1] = self.tracks[id - 1]._replace(bbox=box)
299
+ beginX = box[0, 0]
300
+ beginY = box[0, 1]
301
+ endX = box[0, 2]
302
+ endY = box[0, 3]
303
+
304
+ beginX = beginX + velocity[0] >= 0 and (beginX + velocity[0]) or 0
305
+ beginY = beginY + velocity[1] >= 0 and (beginY + velocity[1]) or 0
306
+ # endX = beginX + self.box_size * 2 < self.spike_h and (beginX + self.box_size * 2) or (self.spike_h - 1)
307
+ # endY = beginY + self.box_size * 2 < self.spike_w and (beginY + self.box_size * 2) or (self.spike_w - 1)
308
+ endX = endX + velocity[0] < self.spike_h and (endX + velocity[0]) or (self.spike_h - 1)
309
+ endY = endY + velocity[1] < self.spike_w and (endY + velocity[1]) or (self.spike_w - 1)
310
+
311
+ tmp_box = torch.tensor([beginX, beginY, endX, endY])
312
+ tmp_box = torch.unsqueeze(tmp_box, dim=0)
313
+ self.tracks[id - 1] = self.tracks[id - 1]._replace(predbox=tmp_box)
314
+
315
+ self.tracks[id - 1] = self.tracks[id - 1]._replace(visible=1)
316
+ self.tracks[id - 1] = self.tracks[id - 1]._replace(vel=velocity)
317
+ self.tracks[id - 1] = self.tracks[id - 1]._replace(unvisible_count=0)
318
+ self.tracks[id - 1] = self.tracks[id - 1]._replace(age=self.tracks[id - 1].age + 1)
319
+
320
+ # update the trajectories
321
+ self.trajectories[id - 1].x.append((box[0, 0] + self.box_size).item())
322
+ self.trajectories[id - 1].y.append((box[0, 1] + self.box_size).item())
323
+ self.trajectories[id - 1].t.append(timestep)
324
+
325
+ # Check if beginX, beginY, endX, endY are int; otherwise, use .item()
326
+ self.tracks_bbox[id - 1].x.append(beginY if isinstance(beginY, int) else beginY.item())
327
+ self.tracks_bbox[id - 1].y.append(beginX if isinstance(beginX, int) else beginX.item())
328
+ self.tracks_bbox[id - 1].h.append(
329
+ (endX - beginX) if isinstance(endX, int) and isinstance(beginX, int) else (endX - beginX).item())
330
+ self.tracks_bbox[id - 1].w.append(
331
+ (endY - beginY) if isinstance(endY, int) and isinstance(beginY, int) else (endY - beginY).item())
332
+
333
+ self.tracks_bbox[id - 1].t.append(timestep)
334
+ AssignTrk.append(id)
335
+ # print('tracks %d velocity dx: %f dy: %f' % (id, velocity[0], velocity[1]))
336
+
337
+ all_id = list(range(1, self.K2 + 1, 1))
338
+ noAssign = [x for x in all_id if x not in AssignTrk]
339
+
340
+ noAssign_num = self.K2 - len(AssignTrk)
341
+
342
+ for iObject in range(noAssign_num):
343
+ id = noAssign[iObject]
344
+ unvisible_count = self.tracks[id - 1].unvisible_count
345
+ if unvisible_count > 5:
346
+ self.tracks[id - 1] = self.tracks[id - 1]._replace(age=0)
347
+ self.tracks[id - 1] = self.tracks[id - 1]._replace(visible=0)
348
+ # sw[id-1, :, :] = 1 / self.K1
349
+ # bw[id-1] = 1 / self.K2
350
+ else:
351
+ if self.tracks[id - 1].predbox[0, 2] != 0:
352
+ self.tracks[id - 1] = self.tracks[id - 1]._replace(bbox=self.tracks[id - 1].predbox)
353
+ beginX = self.tracks[id - 1].predbox[0, 0].item()
354
+ beginY = self.tracks[id - 1].predbox[0, 1].item()
355
+ endX = self.tracks[id - 1].predbox[0, 2].item()
356
+ endY = self.tracks[id - 1].predbox[0, 3].item()
357
+
358
+ beginX = beginX + self.tracks[id - 1].vel[0] >= 0 and (beginX + self.tracks[id - 1].vel[0]) or 0
359
+ beginY = beginY + self.tracks[id - 1].vel[1] >= 0 and (beginY + self.tracks[id - 1].vel[1]) or 0
360
+ # endX = beginX + self.box_size * 2 < self.spike_h and (beginX + self.box_size * 2) or (self.spike_h - 1)
361
+ # endY = beginY + self.box_size * 2 < self.spike_w and (beginY + self.box_size * 2) or (self.spike_w - 1)
362
+ endX = endX + self.tracks[id - 1].vel[0] < self.spike_h and (endX + self.tracks[id - 1].vel[0]) or (
363
+ self.spike_h - 1)
364
+ endY = endY + self.tracks[id - 1].vel[1] < self.spike_w and (endY + self.tracks[id - 1].vel[1]) or (
365
+ self.spike_w - 1)
366
+
367
+ pred_box = torch.tensor([beginX, beginY, endX, endY])
368
+ pred_box = torch.unsqueeze(pred_box, dim=0)
369
+ self.tracks[id - 1] = self.tracks[id - 1]._replace(predbox=pred_box)
370
+
371
+ self.trajectories[id - 1].x.append(
372
+ (beginX + self.box_size) if isinstance(beginX, int) else (beginX + self.box_size).item())
373
+ self.trajectories[id - 1].y.append(
374
+ (beginY + self.box_size) if isinstance(beginY, int) else (beginY + self.box_size).item())
375
+ self.trajectories[id - 1].t.append(timestep)
376
+
377
+ # Check if beginX, beginY, endX, endY are int; otherwise, use .item()
378
+ self.tracks_bbox[id - 1].x.append(beginY if isinstance(beginY, int) else beginY.item())
379
+ self.tracks_bbox[id - 1].y.append(beginX if isinstance(beginX, int) else beginX.item())
380
+ self.tracks_bbox[id - 1].h.append(
381
+ (endX - beginX) if isinstance(endX, int) and isinstance(beginX, int) else (
382
+ endX - beginX).item())
383
+ self.tracks_bbox[id - 1].w.append(
384
+ (endY - beginY) if isinstance(endY, int) and isinstance(beginY, int) else (
385
+ endY - beginY).item())
386
+ self.tracks_bbox[id - 1].t.append(timestep)
387
+ # print('predicting location of object %d the %d time' % (id, unvisible_count))
388
+
389
+ # print('tracks %d predictive velocity dx: %f, dy: %f' % (
390
+ # id, self.tracks[id - 1].vel[0], self.tracks[id - 1].vel[1]))
391
+ if ~(torch.all(self.synaptic_weight == 1)):
392
+ sw[id - 1, :, :] = self.synaptic_weight[id - 1, :, :].detach().clone()
393
+ bw[id - 1] = self.bias_weight[id - 1].detach().clone()
394
+ # print('correct the weight')
395
+ self.tracks[id - 1] = self.tracks[id - 1]._replace(unvisible_count=self.tracks[id - 1].unvisible_count + 1)
396
+
397
+ return sw, bw
snnTracker/spkProc/filters/stp_filters_torch.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2021/11/19 16:25
3
+ # @Author : Yajing Zheng
4
+ # @File : stp_filters_torch.py
5
+ import copy
6
+
7
+ import torch
8
+ import numpy as np
9
+
10
+ class STPFilter:
11
+
12
+ def __init__(self, spike_h, spike_w, device, diff_time=1, **STPargs):
13
+ self.spike_h = spike_h
14
+ self.spike_w = spike_w
15
+ self.device = device
16
+
17
+ # specify stp parameters
18
+ if STPargs.get('u0', None) is None:
19
+ self.u0 = 0.1
20
+ self.D = 0.02
21
+ self.F = 1.7
22
+ self.f = 0.11
23
+ self.time_unit = 2000
24
+
25
+ else:
26
+ self.u0 = STPargs.get('u0')
27
+ self.D = STPargs.get('D')
28
+ self.F = STPargs.get('F')
29
+ self.f = STPargs.get('f')
30
+ self.time_unit = STPargs.get('time_unit')
31
+
32
+ self.r0 = 1
33
+
34
+ self.diff_time = diff_time # duration of window for record past dynamics for calculating the differnece
35
+ self.R = torch.ones(self.spike_h, self.spike_w) * self.r0
36
+ self.u = torch.ones(self.spike_h, self.spike_w) * self.u0
37
+ self.r_old = torch.ones(self.diff_time, self.spike_h, self.spike_w) * self.r0
38
+
39
+ self.R = self.R.to(self.device)
40
+ self.u = self.u.to(self.device)
41
+ self.r_old = self.r_old.to(self.device)
42
+
43
+ # LIF detect layer parameters
44
+ self.detectVoltage = torch.zeros(self.spike_h, self.spike_w).to(self.device)
45
+ if STPargs.get('lifSize', None) is None:
46
+ lifSize = 3
47
+ paddingSize = 1
48
+ else:
49
+ lifSize = STPargs.get('lifSize')
50
+ paddingSize = int((lifSize - 1) / 2)
51
+
52
+ self.lifConv = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(lifSize, lifSize),
53
+ padding=(paddingSize, paddingSize),
54
+ bias=False)
55
+ self.lifConv.weight.data = torch.ones(1, 1, lifSize, lifSize) * 3.0
56
+
57
+ self.lifConv = self.lifConv.to(self.device)
58
+ if STPargs.get('filterThr', None) is None:
59
+ self.filterThr = 0.1 # filter threshold
60
+ self.voltageMin = -8
61
+ self.lifThr = 2
62
+ else:
63
+ self.filterThr = STPargs.get('filterThr')
64
+ self.voltageMin = STPargs.get('voltageMin')
65
+ self.lifThr = STPargs.get('lifThr')
66
+
67
+ self.filter_spk = torch.zeros(self.spike_h, self.spike_w).to(self.device)
68
+ self.lif_spk = torch.zeros(self.spike_h, self.spike_w).to(self.device)
69
+ self.spikePrevMnt = torch.zeros([self.spike_h, self.spike_w], device=self.device)
70
+ self.stp_gradient = 0
71
+ self.adjusted_threshold = torch.zeros(self.spike_h, self.spike_w).to(self.device)
72
+
73
+ def update_dynamics(self, curT, spikes):
74
+
75
+ spikeCurMnt = self.spikePrevMnt.detach().clone()
76
+ spike_bool = spikes.bool()
77
+ spikeCurMnt[spike_bool] = curT + 1
78
+ dttimes = spikeCurMnt - self.spikePrevMnt
79
+ dttimes = dttimes / self.time_unit
80
+ exp_D = torch.exp((-dttimes[spike_bool] / self.D))
81
+ self.R[spike_bool] = 1 - (1 - self.R[spike_bool] * (1 - self.u[spike_bool])) * exp_D
82
+ exp_F = torch.exp((-dttimes[spike_bool] / self.F))
83
+ self.u[spike_bool] = self.u0 + (
84
+ self.u[spike_bool] + self.f * (1 - self.u[spike_bool]) - self.u0) * exp_F
85
+
86
+ tmp_diff = torch.abs(self.R - self.r_old[0])
87
+ # 根据梯度动态调整滤波器阈值
88
+ self.stp_gradient = (0.5 * self.stp_gradient + 0.5 * torch.div(tmp_diff, self.R))
89
+ gradient_sqrt = torch.from_numpy(np.sqrt(self.stp_gradient.cpu().numpy()) + 1).to(self.device)
90
+ self.adjusted_threshold = torch.div(self.filterThr, gradient_sqrt)
91
+
92
+ self.filter_spk[:] = 0
93
+ # self.filter_spk[spike_bool & (tmp_diff >= self.filterThr)] = 1
94
+ self.filter_spk[spike_bool & (tmp_diff >= self.adjusted_threshold)] = 1
95
+
96
+ if curT < self.diff_time:
97
+ self.r_old[curT] = self.R.detach().clone()
98
+ else:
99
+ self.r_old[0:-1] = self.r_old[1:].detach().clone()
100
+ self.r_old[-1] = self.R.detach().clone()
101
+ self.spikePrevMnt = spikeCurMnt.detach().clone()
102
+ del spikeCurMnt, dttimes, exp_D, exp_F, tmp_diff
103
+
104
+ def update_dynamic_offline(self, spikes, intervals):
105
+
106
+ isi_num = intervals.shape[0]
107
+ R = torch.ones(isi_num, self.spike_h, self.spike_w) * self.r0
108
+ u = torch.ones(isi_num, self.spike_h, self.spike_w) * self.u0
109
+ prev_isi = intervals[0, :, :]
110
+
111
+ for t in range(1, isi_num):
112
+ tmp_isi = intervals[t, :, :]
113
+ update_idx = (tmp_isi != prev_isi) & (spikes[t, :, :] == 1) | (tmp_isi == 1)
114
+ tmp_isi = torch.from_numpy(tmp_isi).to(self.device).float()
115
+
116
+ exp_D = torch.exp((-tmp_isi[update_idx] / self.D))
117
+ self.R[update_idx] = 1 - (1 - self.R[update_idx] * (1 - self.u[update_idx])) * exp_D
118
+ exp_F = torch.exp((-tmp_isi[update_idx] / self.F))
119
+ self.u[update_idx] = self.u0 + (
120
+ self.u[update_idx] + self.f * (1 - self.u[update_idx]) - self.u0) * exp_F
121
+
122
+ tmp_r = self.R.detach().clone()
123
+ tmp_u = self.u.detach().clone()
124
+ R[t, :, :] = copy.deepcopy(tmp_r)
125
+ u[t, :, :] = copy.deepcopy(tmp_u)
126
+
127
+ return R, u
128
+
129
+ def local_connect(self, spikes):
130
+ inputSpk = torch.reshape(spikes, (1, 1, self.spike_h, self.spike_w)).float()
131
+ # tmp_fired = spikes != 0
132
+ self.detectVoltage[spikes == False] -= 1
133
+ tmpRes = self.lifConv(inputSpk)
134
+ tmpRes = torch.squeeze(tmpRes).to(self.device)
135
+ self.detectVoltage += tmpRes.data
136
+ self.detectVoltage[self.detectVoltage < self.voltageMin] = self.voltageMin
137
+
138
+ self.lif_spk[:] = 0
139
+ self.lif_spk[self.detectVoltage >= self.lifThr] = 1
140
+ self.detectVoltage[self.detectVoltage >= self.lifThr] *= 0.8
141
+ # self.detectVoltage[(self.detectVoltage < self.lifThr) & (self.detectVoltage > 0)] = 0
142
+
143
+ del inputSpk, tmpRes
144
+
145
+ def local_connect_offline(self, spikes):
146
+ timestamps = spikes.shape[0]
147
+ tmp_voltage = []
148
+ lif_spk = []
149
+
150
+ for iSpk in range(timestamps):
151
+ tmp_spikes = spikes[iSpk]
152
+ tmp_spk = torch.from_numpy(spikes[iSpk]).to(self.device)
153
+ inputSpk = torch.reshape(tmp_spk, (1, 1, self.spike_h, self.spike_w)).float()
154
+ # tmp_fired = spikes != 0
155
+ self.detectVoltage[tmp_spikes == 0] -= 1
156
+ tmpRes = self.lifConv(inputSpk)
157
+ tmpRes = torch.squeeze(tmpRes).to(self.device)
158
+ self.detectVoltage += tmpRes.data
159
+ self.detectVoltage[self.detectVoltage < self.voltageMin] = self.voltageMin
160
+
161
+ self.lif_spk[:] = 0
162
+ self.lif_spk[self.detectVoltage >= self.lifThr] = 1
163
+ # self.detectVoltage[(self.detectVoltage < self.lifThr) & (self.detectVoltage > 0)] = 0
164
+ self.detectVoltage[self.detectVoltage >= self.lifThr] *= 0.8
165
+ voltage = self.detectVoltage.cpu().detach().numpy()
166
+ tmp_voltage.append(copy.deepcopy(voltage))
167
+ lif_spk.append(self.lif_spk.cpu().detach().numpy())
168
+
169
+ del inputSpk, tmpRes
170
+ return tmp_voltage, lif_spk
snnTracker/spkProc/motion/motion_detection.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils import get_kernel, get_transform_matrix_new, visualize_images
2
+ import torchgeometry as tgm
3
+ import numpy as np
4
+ import torch.nn.functional as F
5
+ import torch
6
+
7
class motion_estimation:
    """Two-layer SNN motion estimator.

    Layer M1 learns, per pixel, a weight for every (direction, speed)
    pattern via STDP against the previous frame; a local winner-take-all
    layer MC then votes on the dominant direction and produces per-pixel
    motion vectors.
    """

    def __init__(self, dvs_h, dvs_w, device, logger):
        """
        Args:
            dvs_h, dvs_w: spike-plane height and width.
            device: torch device for all state tensors.
            logger: tensorboard-style logger used by ``local_wta`` when
                ``visualize`` is enabled.
        """
        self.dvs_h = dvs_h
        self.dvs_w = dvs_w
        self.device = device
        self.logger = logger

        # Kept for interface compatibility; the direction set actually used
        # below is self.ori.
        self.orientation = range(0, 180 - 1, int(180 / 4))

        # Eight unit direction vectors (counter-clockwise, starting at +x).
        self.ori = np.array([[1, 0],
                             [1, 1],
                             [0, 1],
                             [-1, 1],
                             [-1, 0],
                             [-1, -1],
                             [0, -1],
                             [1, -1]], dtype=np.int32)
        # Candidate speeds in pixels per step.
        self.speed = np.array([1, 2, 3, 4, 5, 6], dtype=np.int32)
        self.ori_x = torch.from_numpy(np.expand_dims(self.ori[:, 0], axis=1)).to(self.device).float()
        self.ori_y = torch.from_numpy(np.expand_dims(self.ori[:, 1], axis=1)).to(self.device).float()

        self.warp_matrix = get_transform_matrix_new(self.ori, self.speed, self.dvs_w, self.dvs_h, self.device)
        self.track_pre = torch.zeros(self.dvs_h, self.dvs_w)

        self.num_ori = len(self.ori)
        self.num_speed = len(self.speed)
        self.motion_pattern_num = self.num_ori * self.num_speed
        # One weight map per (direction, speed) pattern, uniform at start.
        self.motion_weight = torch.ones(self.motion_pattern_num, 1, self.dvs_h, self.dvs_w) / self.motion_pattern_num
        self.tracking_threshold = 1

        # Local sum used for STDP normalisation.
        self.local_pool_size = 11
        padding_width = int((self.local_pool_size - 1) / 2)
        self.pool_kernel = torch.nn.Conv2d(in_channels=1, out_channels=1,
                                           kernel_size=(self.local_pool_size, self.local_pool_size),
                                           padding=(padding_width, padding_width), bias=False)
        self.pool_kernel.weight.data = torch.ones(1, 1, self.local_pool_size, self.local_pool_size)

        self.gaussian_kernel = torch.nn.Conv2d(in_channels=1, out_channels=1,
                                               kernel_size=(self.local_pool_size, self.local_pool_size),
                                               padding=(padding_width, padding_width), bias=False)
        tmp_filter = get_kernel(self.local_pool_size, round(self.local_pool_size / 4))
        tmp_filter = tmp_filter.reshape((1, 1, self.local_pool_size, self.local_pool_size))
        self.gaussian_kernel.weight.data = torch.from_numpy(tmp_filter).float()

        # Local winner-take-all inhibition window.
        self.inh_size = 25
        self.padding_width = int((self.inh_size - 1) / 2)
        self.inhb_kernel = torch.nn.Conv2d(in_channels=1, out_channels=1,
                                           kernel_size=(self.inh_size, self.inh_size),
                                           padding=(self.padding_width, self.padding_width), bias=False)
        self.inhb_kernel.weight.data = torch.ones(1, 1, self.inh_size, self.inh_size)
        self.inhb_threshold = 5

        self.track_pre = self.track_pre.to(self.device)
        self.motion_weight = self.motion_weight.to(self.device)
        self.pool_kernel = self.pool_kernel.to(self.device)
        self.gaussian_kernel = self.gaussian_kernel.to(self.device)
        self.inhb_kernel = self.inhb_kernel.to(self.device)

        # Colours used to visualise the 8 direction bins.
        cc_motion = [[0, 255, 255],
                     [205, 95, 85],
                     [11, 134, 184],
                     [255, 255, 0],
                     [154, 250, 0],
                     [147, 20, 255],
                     [240, 32, 160],
                     [48, 48, 255]]
        cc_motion = np.transpose(np.array(cc_motion, dtype=np.float32))
        self.cc_motion = torch.from_numpy(cc_motion / 255).to(self.device)
        self.learning_rate = 0.1

    def stdp_tracking(self, spikes):
        """STDP update of the per-pattern motion weights from one frame.

        For every (direction, speed) pattern the current spikes are shifted
        one step along that motion; agreement with the previous frame
        potentiates the pattern, a predicted-but-absent spike depresses it.
        """
        track_post = torch.reshape(spikes, (1, 1, self.dvs_h, self.dvs_w))
        tmp_pool = self.pool_kernel(track_post)
        tmp_pool = tmp_pool.repeat(self.motion_pattern_num, 1, 1, 1)

        predict_fired = torch.zeros(self.motion_pattern_num, 1, self.dvs_h, self.dvs_w).to(self.device)
        fire_idx = torch.where(spikes != 0)

        for i_ori in range(self.num_ori):
            for i_speed in range(self.num_speed):
                i_motion = i_ori * self.num_speed + i_speed
                x = fire_idx[0] + self.ori[i_ori, 0] * self.speed[i_speed]
                y = fire_idx[1] + self.ori[i_ori, 1] * self.speed[i_speed]
                # Drop predictions shifted off the sensor.  (The original
                # redirected them to pixel (0, 0), creating spurious spikes
                # there.)
                valid_idx = ~torch.logical_or(torch.logical_or(x > self.dvs_h - 1, x < 0),
                                              torch.logical_or(y > self.dvs_w - 1, y < 0))
                predict_fired[i_motion, 0, x[valid_idx], y[valid_idx]] = 1

        track_pre_exp = torch.unsqueeze(self.track_pre, 0).repeat(self.motion_pattern_num, 1, 1)
        track_pre_exp = torch.unsqueeze(track_pre_exp, 1)

        # STDP terms: +1 where the shifted prediction matches the previous
        # frame, -2 where it predicted a spike that did not occur.
        dw_ltd = torch.zeros(self.motion_pattern_num, 1, self.dvs_h, self.dvs_w).to(self.device)
        dw_ltp = torch.zeros(self.motion_pattern_num, 1, self.dvs_h, self.dvs_w).to(self.device)

        tmp_bool = torch.eq(predict_fired, track_pre_exp)
        index = torch.where(torch.logical_and(tmp_bool, track_pre_exp == 1))
        if len(index[0]) != 0:
            dw_ltp[index] = 1

        index = torch.where(torch.logical_and(~tmp_bool, predict_fired == 1))
        if len(index[0]) != 0:
            dw_ltd[index] = 2

        dw_ltp = self.pool_kernel(dw_ltp)
        dw_ltd = self.pool_kernel(dw_ltd)

        # Normalise by the local spike count; NaNs produced by empty
        # neighbourhoods are zeroed after the weight update below.
        dw = torch.div(dw_ltp - dw_ltd, tmp_pool)
        self.motion_weight += self.learning_rate * dw.detach().clone()

        # Min-max normalise across patterns (vectorised; the original looped
        # over patterns with identical effect).
        max_weight, _ = torch.max(self.motion_weight, dim=0)
        min_weight, _ = torch.min(self.motion_weight, dim=0)
        self.motion_weight = ((self.motion_weight - min_weight) / (max_weight - min_weight)).detach()

        self.motion_weight[torch.isnan(self.motion_weight)] = 0
        self.track_pre = spikes.detach().clone()

        del track_post, tmp_pool, predict_fired, track_pre_exp, tmp_bool, dw
        del max_weight, min_weight, spikes
        del dw_ltd, dw_ltp
        torch.cuda.empty_cache()

    def local_wta(self, spikes, timestamp, visualize=False):
        """Winner-take-all vote on the learned motion weights.

        Args:
            spikes: (dvs_h, dvs_w) binary spike plane.
            timestamp: step index, used only for logging.
            visualize: when True, log colour-coded M1/MC direction maps.

        Returns:
            max_motion: (h, w) long tensor of winning direction ids + 1
                (0 = no winner) after local inhibition.
            motion_vector_max: (h, w, 2) motion vectors of the winners.
            motion_vector_layer1: (h, w, 2) raw per-pixel M1 vectors.
        """
        motion_vector_layer1 = torch.zeros(self.dvs_h, self.dvs_w, 2, dtype=torch.float32).to(self.device)
        max_w, _ = torch.max(self.motion_weight, dim=0)
        max_w = torch.squeeze(max_w)

        # Average motion vector per pixel: weights reshaped to
        # (h, w, speed, ori) and projected on the direction vectors.
        tmp_weight = self.motion_weight.permute(2, 3, 1, 0)
        tmp_weight = torch.reshape(tmp_weight, [self.dvs_h, self.dvs_w, self.num_ori, self.num_speed])
        tmp_weight = tmp_weight.permute(0, 1, 3, 2)
        tmp_weight_x = torch.squeeze(torch.mean(torch.matmul(tmp_weight, self.ori_x), dim=2))
        tmp_weight_y = torch.squeeze(torch.mean(torch.matmul(tmp_weight, self.ori_y), dim=2))

        fired_spk_index2d = torch.where(torch.logical_and(spikes != 0, max_w > 0))
        dx = tmp_weight_x[fired_spk_index2d]
        dy = tmp_weight_y[fired_spk_index2d]
        motion_vector_layer1[fired_spk_index2d[0], fired_spk_index2d[1], 0] = dx
        motion_vector_layer1[fired_spk_index2d[0], fired_spk_index2d[1], 1] = dy

        # Quantise each vector into one of 8 direction bins.
        rotAng = np.arctan2(-dy.cpu().numpy(), dx.cpu().numpy()) * 180 / np.pi + 180
        rotAng[np.where(rotAng == 360)] = 0
        # astype fixes the original float array, which is not a valid tensor
        # index dtype.
        tmp_motion = np.floor(rotAng / (360 / 8)).astype(np.int64)

        track_voltage = torch.zeros(self.num_ori, self.dvs_h, self.dvs_w)
        track_voltage[tmp_motion, fired_spk_index2d[0], fired_spk_index2d[1]] = 1
        track_voltage = torch.unsqueeze(track_voltage, 1).to(self.device)
        track_voltage = torch.squeeze(self.inhb_kernel(track_voltage))
        max_v, max_vid = torch.max(track_voltage, dim=0)

        fired_layer2_index = torch.where(
            torch.logical_and(max_v >= self.inhb_threshold, torch.logical_and(spikes != 0, max_w > 0)))
        max_motion = torch.zeros(self.dvs_h, self.dvs_w, dtype=torch.int64).to(self.device)
        max_motion_layer1 = torch.zeros(self.dvs_h, self.dvs_w, dtype=torch.int64).to(self.device)
        motion_vector_max = torch.zeros(self.dvs_h, self.dvs_w, 2, dtype=torch.float32).to(self.device)

        max_motion[fired_layer2_index] = max_vid[fired_layer2_index].detach() + 1
        motion_tensor = torch.from_numpy(tmp_motion + 1).to(self.device)
        max_motion_layer1[fired_spk_index2d[0], fired_spk_index2d[1]] = motion_tensor.long()
        max_motion_layer1[max_motion == 0] = 0

        tmp_vid = max_vid[fired_layer2_index].cpu().detach().numpy()
        if len(tmp_vid) != 0:
            motion_vector_max[fired_layer2_index] = motion_vector_layer1[fired_layer2_index].detach()
            # Pixels whose own direction lost the local vote copy the motion
            # vector of the local winner inside the inhibition window.
            loser_pattern_index = torch.where(torch.logical_and(max_motion != 0, max_motion_layer1 != max_motion))
            fired2_index_x = loser_pattern_index[0]
            fired2_index_y = loser_pattern_index[1]
            voltage_block = max_v[None, None, :, :]
            voltage_block = F.pad(voltage_block,
                                  (self.padding_width, self.padding_width, self.padding_width, self.padding_width),
                                  mode='constant', value=0)
            voltage_block = F.unfold(voltage_block, (self.inh_size, self.inh_size))
            voltage_block = voltage_block.reshape([1, self.inh_size * self.inh_size, self.dvs_h, self.dvs_w])
            offset_pattern = torch.squeeze(torch.argmax(voltage_block, dim=1))
            offset_pattern_loser = offset_pattern[fired2_index_x, fired2_index_y]
            # Row/column of the winner inside the window.  Floor division
            # replaces the original true-division + int(), whose truncation
            # toward zero mis-addressed rows above the centre.
            offset_x = (torch.div(offset_pattern_loser, self.inh_size, rounding_mode='floor')
                        - self.padding_width).int()
            offset_y = (torch.fmod(offset_pattern_loser, self.inh_size) - self.padding_width).int()
            # NOTE(review): fired2_index + offset can leave the image and
            # wrap via negative indexing, as in the original — confirm.
            motion_vector_max[fired2_index_x, fired2_index_y, :] = \
                motion_vector_max[fired2_index_x + offset_x, fired2_index_y + offset_y, :]

        if visualize is True:
            Image_layer1 = torch.zeros(3, self.dvs_h, self.dvs_w).to(self.device)
            Image_layer1[:, fired_spk_index2d[0], fired_spk_index2d[1]] = self.cc_motion[:, tmp_motion]

            Image_layer2 = torch.zeros(3, self.dvs_h, self.dvs_w).to(self.device)
            Image_layer2[:, fired_layer2_index[0], fired_layer2_index[1]] = self.cc_motion[:, tmp_vid]

            self.logger.add_image('motion_estimation/M1 estimation', Image_layer1, timestamp)
            self.logger.add_image('motion_estimation/MC estimation', Image_layer2, timestamp)

        del fired_spk_index2d, fired_layer2_index
        del track_voltage, dx, dy
        torch.cuda.empty_cache()

        return max_motion, motion_vector_max, motion_vector_layer1
snnTracker/spkProc/tracking/snn_tracker.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2023/7/16 20:23
3
+ # @Author : Yajing Zheng
4
+ # @Email: [email protected]
5
+ # @File : snn_tracker.py
6
+ import os, sys
7
+ sys.path.append('../..')
8
+ import time
9
+ import numpy as np
10
+ import torch
11
+
12
+ from spkProc.filters.stp_filters_torch import STPFilter
13
+ # from filters import stpFilter
14
+ from spkProc.detection.attention_select import SaccadeInput
15
+ from spkProc.motion.motion_detection import motion_estimation
16
+ from spkProc.detection.stdp_clustering import stdp_cluster
17
+ from utils import NumpyEncoder
18
+ from collections import namedtuple
19
+ import json
20
+ import cv2
21
+ from tqdm import tqdm
22
+
23
+ trajectories = namedtuple('trajectories', ['id', 'x', 'y', 't', 'color'])
24
+
25
class SNNTracker:
    """Spiking-neural-network multi-object tracker for vidar spike streams.

    Per-frame pipeline: STP filtering of static background -> DNF attention
    (candidate boxes) -> STDP winner-take-all motion estimation -> STDP
    clustering with track management. Tracking results are appended to a text
    file in MOT format and can optionally be rendered into a video.
    """

    def __init__(self, spike_h, spike_w, device, attention_size=20, diff_time=1, **STPargs):
        """
        Args:
            spike_h, spike_w: spatial resolution of the spike stream.
            device: torch device shared by all sub-modules.
            attention_size: side length of the attention (detection) box.
            diff_time: temporal-difference window forwarded to the STP filter.
            **STPargs: optional STPFilter overrides (e.g. filterThr, voltageMin, lifThr).
        """
        self.spike_h = spike_h
        self.spike_w = spike_w
        self.device = device

        # STPargs is a kwargs dict and therefore never None; an empty dict
        # simply applies the STPFilter defaults (the previous
        # `if STPargs is not None` branch was always taken).
        self.stp_filter = STPFilter(spike_h, spike_w, device, diff_time, **STPargs)
        self.attention_size = attention_size
        self.object_detection = SaccadeInput(spike_h, spike_w, box_size=self.attention_size, device=device)
        from tensorboardX import SummaryWriter
        logger = SummaryWriter(log_dir='data/log_pkuvidar')
        self.motion_estimator = motion_estimation(spike_h, spike_w, device, logger=logger)

        self.object_cluster = stdp_cluster(spike_h, spike_w, box_size=self.attention_size, device=device)

        self.calibration_time = 150  # frames used only to settle filter dynamics
        self.timestamps = 0          # global frame counter (calibration + tracking)
        self.trajectories = {}       # track id -> `trajectories` namedtuple
        self.filterd_spikes = []

    def calibrate_motion(self, spikes, calibration_time=None):
        """Warm up the STP filter on the first `calibration_time` frames.

        Advances `self.timestamps` so subsequent tracking continues the same
        global clock.
        """
        if calibration_time is None:
            calibration_time = self.calibration_time
        else:
            self.calibration_time = calibration_time

        print('begin calibrate..')
        for t in range(calibration_time):
            input_spk = torch.from_numpy(spikes[t, :, :]).to(self.device)
            self.stp_filter.update_dynamics(t, input_spk)
            self.timestamps += 1

    def get_results(self, spikes, res_filepath, mov_writer=None, save_video=False):
        """Track objects over `spikes` and append MOT rows to `res_filepath`.

        Each visible track writes one line per frame:
        ``frame,id,x,y,w,h,1,-1,-1,-1``. When `save_video` is True, annotated
        frames are written through `mov_writer` (the caller owns/releases it).
        """
        result_file = open(res_filepath, 'a+')

        timestamps = spikes.shape[0]
        predict_kwargs = {'spike_h': self.spike_h, 'spike_w': self.spike_w, 'device': self.device}

        # FIX: actually time the loop (total_time previously stayed 0) and
        # show the target path in the progress bar instead of the file-object
        # repr. (Unused `dets`/`track_ids` scaffolding removed.)
        start_time = time.time()
        for t in tqdm(range(timestamps), desc=f'Saving tracking results to {res_filepath}'):
            try:
                input_spk = torch.from_numpy(spikes[t, :, :]).to(self.device)
                self.stp_filter.update_dynamics(self.timestamps, input_spk)
                self.stp_filter.local_connect(self.stp_filter.filter_spk)

                # attention-based candidate detection on the filtered spikes
                self.object_detection.update_dnf(self.stp_filter.lif_spk)
                attentionBox, attentionInput = self.object_detection.get_attention_location(self.stp_filter.lif_spk)
                self.motion_estimator.stdp_tracking(self.stp_filter.lif_spk)

                motion_id, motion_vector, _ = self.motion_estimator.local_wta(
                    self.stp_filter.lif_spk, self.timestamps, visualize=True)

                # STDP clustering: predict firing, detect objects, update tracks
                predict_fire, sw, bw = self.object_cluster.update_weight(attentionInput)
                predict_object = self.object_cluster.detect_object(
                    predict_fire, attentionBox, motion_id, motion_vector, **predict_kwargs)
                sw, bw = self.object_cluster.update_tracks(predict_object, sw, bw, self.timestamps)

                self.object_cluster.synaptic_weight = sw.detach().clone()
                self.object_cluster.bias_weight = bw.detach().clone()

                if save_video:
                    track_frame = self.stp_filter.lif_spk.cpu().numpy()
                    track_frame = (track_frame * 255).astype(np.uint8)
                    track_frame = cv2.cvtColor(track_frame, cv2.COLOR_GRAY2BGR)

                    # raw attention proposals in red
                    for i_box in range(attentionBox.shape[0]):
                        tmp_box = attentionBox[i_box, :]
                        cv2.rectangle(track_frame, (int(tmp_box[1]), int(tmp_box[0])),
                                      (int(tmp_box[3]), int(tmp_box[2])),
                                      (int(0), int(0), int(255)), 2)

                for i_box in range(self.object_cluster.K2):
                    if self.object_cluster.tracks[i_box].visible == 1:
                        tmp_box = self.object_cluster.tracks[i_box].bbox.numpy()
                        id = self.object_cluster.tracks[i_box].id
                        color = self.object_cluster.tracks[i_box].color

                        # bbox rows are [y0, x0, y1, x1]
                        mid_y = (tmp_box[0, 0] + tmp_box[0, 2]) / 2  # height
                        mid_x = (tmp_box[0, 1] + tmp_box[0, 3]) / 2  # width
                        box_w = int(tmp_box[0, 3] - tmp_box[0, 1])
                        box_h = int(tmp_box[0, 2] - tmp_box[0, 0])
                        print('%d,%d,%.2f,%.2f,%.2f,%.2f,1,-1,-1,-1' % (
                            self.timestamps, id, tmp_box[0, 1], tmp_box[0, 0], box_w, box_h),
                            file=result_file)

                        # update the trajectories (both branches previously did
                        # the same appends; create-then-append merges them)
                        if id not in self.trajectories:
                            self.trajectories[id] = trajectories(int(id), [], [], [], 255 * np.random.rand(1, 3))
                        self.trajectories[id].x.append(mid_x)
                        self.trajectories[id].y.append(mid_y)
                        self.trajectories[id].t.append(self.timestamps)

                        if save_video:
                            # the detection box in the track's own color
                            cv2.rectangle(track_frame, (int(tmp_box[0, 1]), int(tmp_box[0, 0])),
                                          (int(tmp_box[0, 3]), int(tmp_box[0, 2])),
                                          (int(color[0, 0]), int(color[0, 1]), int(color[0, 2])), 2)

                            # the filled label box above the detection
                            cv2.rectangle(track_frame, (int(tmp_box[0, 1]), int(tmp_box[0, 0] - 35)),
                                          (int(tmp_box[0, 1] + 60), int(tmp_box[0, 0])),
                                          (int(color[0, 0]), int(color[0, 1]), int(color[0, 2])), -1)
                            if self.object_cluster.tracks[i_box].unvisible_count > 0:
                                show_text = 'predict' + str(id)
                            else:
                                show_text = 'object' + str(id)
                            cv2.putText(track_frame, show_text, (int(tmp_box[0, 1]), int(tmp_box[0, 0] - 10)),
                                        cv2.FONT_HERSHEY_SIMPLEX,
                                        1, (255, 255, 255), 2)

                if save_video:
                    cv2.putText(track_frame, str(int(self.timestamps)),
                                (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 255), 2)
                    mov_writer.write(track_frame)
                self.timestamps += 1

            except RuntimeError as exception:
                # best-effort recovery: skip the frame on CUDA OOM, re-raise
                # anything else
                if "out of memory" in str(exception):
                    print('WARNING: out of memory')
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                else:
                    raise exception

        total_time = time.time() - start_time
        print('Total tracking took: %.3f seconds for %d timestamps spikes' %
              (total_time, self.timestamps - self.calibration_time))

        result_file.close()

    def save_trajectory(self, results_dir, data_name):
        """Dump per-track trajectories and bounding boxes as JSON-lines files.

        Writes `<results_dir>/<data_name>_py.json` (python trajectories) plus
        `results/<data_name>.json` and `results/<data_name>_bbox.json`.
        """
        trajectories_filename = os.path.join(results_dir, data_name + '_py.json')
        mat_trajectories_filename = 'results/' + data_name + '.json'
        track_box_filename = 'results/' + data_name + '_bbox.json'

        # start fresh: these files are opened in append mode below
        for fname in (trajectories_filename, mat_trajectories_filename, track_box_filename):
            if os.path.exists(fname):
                os.remove(fname)

        for i_traj in range(self.object_cluster.K2):
            tmp_traj = self.object_cluster.trajectories[i_traj]
            tmp_bbox = self.object_cluster.tracks_bbox[i_traj]

            traj_json_string = json.dumps(tmp_traj._asdict(), cls=NumpyEncoder)
            bbox_json_string = json.dumps(tmp_bbox._asdict(), cls=NumpyEncoder)

            # FIX: terminate each record with a newline so the files are valid
            # JSON-lines (records previously ran together on one line).
            with open(mat_trajectories_filename, 'a+') as f:
                f.write(traj_json_string + '\n')

            with open(track_box_filename, 'a+') as f:
                f.write(bbox_json_string + '\n')

        for i_traj in self.trajectories:
            traj_json_string = json.dumps(self.trajectories[i_traj]._asdict(), cls=NumpyEncoder)

            with open(trajectories_filename, 'a+') as f:
                f.write(traj_json_string)
                f.write('\n')
+ f.write('\n')
snnTracker/test_motion_detection.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from spkProc.tracking.snn_tracker import SNNTracker
4
+ import matplotlib.pyplot as plt
5
+ import cv2
6
+
7
+
8
def load_vidar_dat(filename, frame_cnt=None, width=640, height=480, reverse_spike=True):
    '''Load a raw packed vidar ``.dat`` spike file.

    Each byte packs 8 pixels, least-significant bit first; frames are stored
    row-major at ``height * width // 8`` bytes per frame.

    Args:
        filename: path to the .dat file.
        frame_cnt: number of frames to decode (None = all complete frames).
        width, height: spatial resolution of the stream.
        reverse_spike: flip each frame vertically (sensor orientation fix).

    output: <class 'numpy.ndarray'> (frame_cnt, height, width) {0,1} float32
    '''
    raw = np.fromfile(filename, dtype=np.uint8)

    len_per_frame = height * width // 8
    framecnt = frame_cnt if frame_cnt is not None else len(raw) // len_per_frame

    # Vectorized decode: np.unpackbits with bitorder='little' yields bits
    # b0..b7 per byte, identical to the original per-bit shift/mask loop
    # (which stacked right_shift(and(byte, 1<<b), b) for b in 0..7).
    usable = raw[:framecnt * len_per_frame]
    spikes = np.unpackbits(usable, bitorder='little').reshape(framecnt, height, width)

    if reverse_spike:
        # flip each frame vertically, same as np.flipud per frame
        spikes = spikes[:, ::-1, :]

    return spikes.astype(np.float32)
32
+
33
+
34
+
35
def detect_motion(spikes, calibration_frames=200, device=None):
    """Detect moving targets with the SNN tracker.

    Args:
        spikes: spike array of shape [frames, height, width].
        calibration_frames: number of leading frames used for calibration.
        device: torch device (CPU/GPU); auto-selected when None.
    Returns:
        Boolean motion mask (numpy) for frame `calibration_frames`.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    h, w = spikes.shape[1:]

    # build the SNN tracker and settle its STP filter on the leading frames
    tracker = SNNTracker(h, w, device, attention_size=15)
    tracker.calibrate_motion(spikes[:calibration_frames], calibration_frames)

    # winner-take-all motion estimation on the frame right after calibration
    frame = torch.from_numpy(spikes[calibration_frames]).to(device)
    motion_id, _motion_vector, _ = tracker.motion_estimator.local_wta(frame, calibration_frames)

    # any pixel with a positive motion id is considered moving
    return (motion_id > 0).cpu().numpy()
69
+
70
def spikes_to_tfi(spk_seq):
    """Texture-from-interval (TFI) reconstruction.

    Converts a binary spike sequence of shape (n, h, w) into per-frame
    intensities as the reciprocal of the inter-spike interval at each pixel.

    NOTE(review): relies on numpy view semantics — `last_frame` and
    `tmp_frame` below are views into `c_frames`, so the masked assignments
    mutate `c_frames` in place; the explicit write-backs are redundant but
    harmless.
    """
    n, h, w = spk_seq.shape
    last_index = np.zeros((1, h, w))
    cur_index = np.zeros((1, h, w))
    c_frames = np.zeros_like(spk_seq).astype(np.float64)
    # Forward pass: at pixels firing in frame i+1, c_frames[i] is the time
    # since that pixel's previous spike; elsewhere it stays 0 for now.
    for i in range(n - 1):
        last_index = cur_index
        cur_index = spk_seq[i+1,:,:] * (i + 1) + (1 - spk_seq[i+1,:,:]) * last_index
        c_frames[i,:,:] = cur_index - last_index
    # Last frame: pixels with no recorded interval fall back to the maximum
    # possible interval n.
    last_frame = c_frames[n-1:,:]
    last_frame[last_frame==0] = n
    c_frames[n-1,:,:] = last_frame
    # Backward pass: propagate the next known interval into frames where the
    # pixel did not fire, so every (frame, pixel) has a nonzero interval.
    last_interval = n * np.ones((1, h, w))
    for i in range(n - 2, -1, -1):
        last_interval = spk_seq[i+1,:,:] * c_frames[i,:,:] + (1 - spk_seq[i+1,:,:]) * last_interval
        tmp_frame = np.expand_dims(c_frames[i,:,:], 0)
        tmp_frame[tmp_frame==0] = last_interval[tmp_frame==0]
        c_frames[i] = tmp_frame
    # Intensity is inversely proportional to the inter-spike interval.
    return 1.0 / c_frames
89
+
90
def detect_object(spikes, calibration_frames=200, device=None):
    """Run SNN object detection/tracking and render an annotated video.

    Args:
        spikes: spike array of shape [frames, height, width].
        calibration_frames: number of leading frames used for calibration.
        device: torch device (CPU/GPU); auto-selected when None.
    Returns:
        0 — the real output is the video written to "testtest.avi".
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    h, w = spikes.shape[1:]

    # tracker with at most 4 simultaneous tracks
    tracker = SNNTracker(h, w, device, attention_size=15)
    tracker.object_cluster.K2 = 4

    # settle the STP filter on the leading frames
    tracker.calibrate_motion(spikes[:calibration_frames], calibration_frames)

    # track the next 200 frames after calibration
    segment = spikes[calibration_frames: calibration_frames + 200]
    print(segment.shape)

    # NOTE(review): the MOT text results are appended to the same path as the
    # video file — likely unintended; confirm with the caller.
    save_filename = "testtest.avi"
    mov = cv2.VideoWriter(save_filename, cv2.VideoWriter_fourcc(*'MJPG'), 30, (400, 250))
    tracker.get_results(segment, save_filename, mov, save_video=True)

    mov.release()
    cv2.destroyAllWindows()
    return 0
126
+
127
+
128
if __name__ == "__main__":
    # Spike-stream geometry of the recorded .dat files.
    height = 250
    width = 400
    # Concatenate the ten raw recordings 0.dat .. 9.dat along time.
    spikes = load_vidar_dat("0.dat", width=width, height=height)
    for n in range(1,10):
        tmp_spikes = load_vidar_dat(f"{n}.dat", width=width, height=height)
        spikes = np.concatenate((spikes, tmp_spikes), axis=0)
    print(spikes.shape)

    # Temporal downsampling: keep every 10th frame.
    spikes = spikes[::10]

    # Run SNN detection/tracking (writes the annotated video as a side effect).
    motion_mask = detect_object(spikes, calibration_frames=200)

    # Reconstruct grayscale intensity via texture-from-interval.
    tfi = spikes_to_tfi(spikes)
    # Save the reconstructed video.
    save_recon_filename = "tfi.avi"
    recon_mov = cv2.VideoWriter(save_recon_filename, cv2.VideoWriter_fourcc(*'MJPG'), 30, (width, height))

    for frame in tfi:
        # tfi values are in (0, 1]; scale to 8-bit and expand to BGR
        frame_norm = (frame * 255).astype(np.uint8)
        frame_rgb = cv2.cvtColor(frame_norm, cv2.COLOR_GRAY2BGR)
        recon_mov.write(frame_rgb)

    recon_mov.release()
154
+
155
+
156
+
157
+ # 检测运动目标
158
+ # motion_mask = detect_motion(spikes, calibration_frames=200)
159
+ # print(f"Motion mask shape: {motion_mask.shape}")
160
+ # print(f"Number of motion pixels: {motion_mask.sum()}")
161
+
162
+ # 可视化运动目标检测结果
163
+ # plt.figure(figsize=(10, 5))
164
+ # plt.subplot(1, 2, 1)
165
+ # plt.imshow(spikes[200], cmap='gray')
166
+ # plt.title("Input frame")
167
+ # plt.axis('off')
168
+ # plt.subplot(1, 2, 2)
169
+ # plt.imshow(motion_mask, cmap='gray')
170
+ # plt.title("Motion mask")
171
+ # plt.axis('off')
172
+ # plt.show()
173
+
174
+
175
+ # 计算原始脉冲图和运动掩码之间的差异
176
+ # spike_frame = spikes[200] # 获取第200帧脉冲图
177
+
178
+ # # 计算差异指标
179
+ # pixel_diff = np.logical_xor(spike_frame > 0, motion_mask).sum()
180
+ # total_pixels = height * width
181
+ # diff_ratio = pixel_diff / total_pixels
182
+
183
+ # print("\n运动检测结果分析:")
184
+ # print(f"原始脉冲图中的活跃像素数: {(spike_frame > 0).sum()}")
185
+ # print(f"运动掩码中的运动像素数: {motion_mask.sum()}")
186
+ # print(f"不一致的像素数: {pixel_diff}")
187
+ # print(f"像素差异比例: {diff_ratio:.2%}")
188
+
189
+ # # 可视化差异
190
+ # plt.figure(figsize=(10, 5))
191
+ # plt.subplot(1, 2, 1)
192
+ # plt.imshow(np.logical_xor(spike_frame > 0, motion_mask), cmap='gray')
193
+ # plt.title("Difference map (white indicates inconsistency)")
194
+ # plt.axis('off')
195
+
196
+ # plt.subplot(1, 2, 2)
197
+ # plt.imshow(spike_frame > 0, cmap='gray', alpha=0.5)
198
+ # plt.imshow(motion_mask, cmap='Reds', alpha=0.5)
199
+ # plt.title("Overlay (Red: Motion mask, Gray: Original spikes)")
200
+ # plt.axis('off')
201
+ # plt.show()
snnTracker/test_snntracker copy.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# -*- coding: utf-8 -*-
# @Time : 2024/12/05 20:17
# @Author : Yajing Zheng
# @Email: [email protected]
# @File : test_snntracker.py
# Driver script: runs the SNN tracker on the local scene "0", writes
# MOT-format results plus an annotated video, then saves and plots the
# resulting trajectories.
import os, sys
sys.path.append("..")
import path

import numpy as np
from spkData.load_dat import data_parameter_dict, SpikeStream
from pprint import pprint
import torch
from spkProc.tracking.snn_tracker import SNNTracker
from utils import vis_trajectory
from visualization.get_video import obtain_mot_video
import cv2
# from tracking_mot import TrackingMetrics

from visualization.get_video import obtain_detection_video

# change the path to where you put the datasets
test_scene = "0"
# data_filename = 'motVidarReal2020/rotTrans'
data_filename = test_scene
label_type = 'tracking'
para_dict = data_parameter_dict(data_filename, label_type)
pprint(para_dict)
vidarSpikes = SpikeStream(**para_dict)

# block_len = 2000
# spikes = vidarSpikes.get_block_spikes(begin_idx=0, block_len=block_len)
spikes = vidarSpikes.get_spike_matrix()
pprint(spikes.shape)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# number of leading frames used only to settle the STP filter
calibration_time = 150
filename = path.split_path_into_pieces(data_filename)
result_filename = filename[-1] + '_snn.txt'
if not os.path.exists('results'):
    os.makedirs('results')
tracking_file = os.path.join('results', result_filename)
if os.path.exists(tracking_file):
    os.remove(tracking_file)

# stp_params = {'filterThr': 0.12, # filter threshold
#               'voltageMin': -10,
#               'lifThr': 3}
spike_tracker = SNNTracker(para_dict.get('spike_h'), para_dict.get('spike_w'), device, attention_size=15)
spike_tracker.object_cluster.K2 = 4  # cap the number of simultaneous tracks
# total_spikes = spikes

# using stp filter to filter out static spikes
spike_tracker.calibrate_motion(spikes, calibration_time)
# start tracking: the remaining frames go through the full pipeline
track_videoName = tracking_file.replace('txt', 'avi')
mov = cv2.VideoWriter(track_videoName, cv2.VideoWriter_fourcc(*'MJPG'), 30, (para_dict.get('spike_w'), para_dict.get('spike_h')))
spike_tracker.get_results(spikes[calibration_time:], tracking_file, mov, save_video=True)

data_name = test_scene
trajectories_filename = os.path.join('results', data_name + '_py.json')
visTraj_filename = os.path.join('results', data_name + '.png')

spike_tracker.save_trajectory('results', data_name)
vis_trajectory(trajectories_filename, visTraj_filename, **para_dict)
# measure the multi-object tracking performance
# metrics = TrackingMetrics(tracking_file, **para_dict)
# metrics.get_results()
#
# block_len = total_spikes.shape[0]
mov.release()
cv2.destroyAllWindows()
# # visualize the tracking results to a video
# video_filename = os.path.join('results', filename[-1] + '_mot.avi')
# obtain_mot_video(spike_tracker.filterd_spikes, video_filename, tracking_file, **para_dict)
# obtain_detection_video(total_spikes, video_filename, tracking_file, evaluate_seq_len=evaluate_seq_len, **para_dict)
+ # obtain_detection_video(total_spikes, video_filename, tracking_file, evaluate_seq_len=evaluate_seq_len, **para_dict)
snnTracker/test_snntracker.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# -*- coding: utf-8 -*-
# @Time : 2024/12/05 20:17
# @Author : Yajing Zheng
# @Email: [email protected]
# @File : test_snntracker.py
# Driver script: runs the SNN tracker on one scene of the motVidarReal2020
# dataset, writes MOT-format results plus an annotated video, then saves and
# plots the resulting trajectories.
import os, sys
sys.path.append("..")
import path

import numpy as np
from spkData.load_dat import data_parameter_dict, SpikeStream
from pprint import pprint
import torch
from spkProc.tracking.snn_tracker import SNNTracker
from utils import vis_trajectory
from visualization.get_video import obtain_mot_video
import cv2
from tracking_mot import TrackingMetrics

from visualization.get_video import obtain_detection_video

# change the path to where you put the datasets
test_scene = ['spike59', 'rotTrans', 'cplCam', 'cpl1', 'badminton', 'ball']
# data_filename = 'motVidarReal2020/rotTrans'
scene_idx = 2  # which scene from test_scene to run
data_filename = 'motVidarReal2020/' + test_scene[scene_idx]
label_type = 'tracking'
para_dict = data_parameter_dict(data_filename, label_type)
pprint(para_dict)
vidarSpikes = SpikeStream(**para_dict)

# block_len = 2000
# spikes = vidarSpikes.get_block_spikes(begin_idx=0, block_len=block_len)
spikes = vidarSpikes.get_spike_matrix()
pprint(spikes.shape)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# number of leading frames used only to settle the STP filter
calibration_time = 150
filename = path.split_path_into_pieces(data_filename)
result_filename = filename[-1] + '_snn.txt'
if not os.path.exists('results'):
    os.makedirs('results')
tracking_file = os.path.join('results', result_filename)
if os.path.exists(tracking_file):
    os.remove(tracking_file)

# stp_params = {'filterThr': 0.12, # filter threshold
#               'voltageMin': -10,
#               'lifThr': 3}
spike_tracker = SNNTracker(para_dict.get('spike_h'), para_dict.get('spike_w'), device, attention_size=15)
spike_tracker.object_cluster.K2 = 4  # cap the number of simultaneous tracks
# total_spikes = spikes

# using stp filter to filter out static spikes
spike_tracker.calibrate_motion(spikes, calibration_time)
# start tracking: the remaining frames go through the full pipeline
track_videoName = tracking_file.replace('txt', 'avi')
mov = cv2.VideoWriter(track_videoName, cv2.VideoWriter_fourcc(*'MJPG'), 30, (para_dict.get('spike_w'), para_dict.get('spike_h')))
spike_tracker.get_results(spikes[calibration_time:], tracking_file, mov, save_video=True)

data_name = test_scene[scene_idx]
trajectories_filename = os.path.join('results', data_name + '_py.json')
visTraj_filename = os.path.join('results', data_name + '.png')

spike_tracker.save_trajectory('results', data_name)
vis_trajectory(trajectories_filename, visTraj_filename, **para_dict)
# measure the multi-object tracking performance
# metrics = TrackingMetrics(tracking_file, **para_dict)
# metrics.get_results()
#
# block_len = total_spikes.shape[0]
mov.release()
cv2.destroyAllWindows()
# # visualize the tracking results to a video
# video_filename = os.path.join('results', filename[-1] + '_mot.avi')
# obtain_mot_video(spike_tracker.filterd_spikes, video_filename, tracking_file, **para_dict)
# obtain_detection_video(total_spikes, video_filename, tracking_file, evaluate_seq_len=evaluate_seq_len, **para_dict)
snnTracker/utils.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import numpy as np
3
+ import torch
4
+ import threading
5
+ import cv2
6
+ import json
7
+
8
+ # import matplotlib
9
+ # matplotlib.use('TkAgg')
10
+ import matplotlib.pyplot as plt
11
+ from mpl_toolkits.mplot3d import Axes3D
12
+ from matplotlib.pyplot import MultipleLocator
13
+
14
+
15
class dataReader(threading.Thread):
    """Background thread that decodes spike frames and pushes them to a queue.

    Supports three input formats: raw packed ``.dat`` spike streams, ``.npy``
    super-resolution dumps, and per-frame ``.png`` images.

    NOTE(review): ``run`` reads module-level globals ``tnum``, ``ivs_w``,
    ``ivs_h`` and ``device`` that are not defined in this file, and
    ``torch.cuda.Stream()`` requires CUDA — this looks like legacy code that
    only works when those names are injected; confirm before use.
    """
    def __init__(self, file_reader, device, q, is_dat=True, is_npy=False, filedir=None):
        super(dataReader, self).__init__()
        self.file_reader = file_reader  # open binary reader used in .dat mode
        self.device = device            # torch device frames are moved to
        self.q = q                      # output queue of boolean spike tensors
        self.is_dat = is_dat            # decode raw packed .dat input
        self.is_npy = is_npy            # decode .npy super-resolution input
        self.filedir = filedir          # directory holding .npy/.png inputs
        self.stream = torch.cuda.Stream()

    def run(self):
        # Decode on a dedicated CUDA stream so uploads can overlap compute.
        with torch.cuda.stream(self.stream):
            for t in range(tnum):
                if self.is_dat:
                    # One frame is ivs_w*ivs_h bits, packed 8 pixels per byte.
                    ibuffer = self.file_reader.read(int(ivs_w * ivs_h / 8))
                    a = bin(int.from_bytes(ibuffer, byteorder=sys.byteorder))
                    a = a[2:].zfill(ivs_w * ivs_h)

                    a = list(a)
                    a = np.array(a, dtype=np.byte)
                    a = np.reshape(a, [ivs_h, ivs_w])
                    # Sensor-specific orientation fixes keyed on frame height.
                    if ivs_h == 600:
                        a = np.flip(a, 0)
                    if ivs_h == 250:
                        a = np.flip(a, 1)
                    input_spk = torch.from_numpy(a != 0).to(device)
                elif self.is_npy:
                    # Each .npy holds several temporal sub-frames; enqueue all.
                    npy_filename = self.filedir + str(t + 442) + '.npy'
                    tmp_data = np.load(npy_filename)
                    superResolution_rate = tmp_data.shape[2]
                    for i_data in range(superResolution_rate):
                        tmp_spk = tmp_data[:, :, i_data]
                        input_spk = torch.from_numpy(tmp_spk).to(device)
                        self.q.put(input_spk)

                else:
                    # img_filename = self.filedir + str(t + 4200) + '.png'
                    img_filename = self.filedir + 'spike_' + str(t + 1) + '.png'
                    # print('reading %d frames' % (t+1))
                    # print('reading %d frames' % (t+5000))
                    a = cv2.imread(img_filename)
                    a = cv2.cvtColor(a, cv2.COLOR_BGR2GRAY)
                    a = a / 255
                    a = np.array(a, dtype=np.byte)
                    input_spk = torch.from_numpy(a != 0).to(device)

                # NOTE(review): in the .npy branch this enqueues the last
                # sub-frame a second time — confirm whether that is intended.
                self.q.put(input_spk)
+
64
+
65
+ # obtain 2D gaussian filter
66
def get_kernel(filter_size, sigma):
    """Return a 2-D Gaussian kernel of shape (filter_size, filter_size).

    The kernel is centered, isotropic, and min-max normalized to [0, 1]
    (center = 1, corners = 0).

    Args:
        filter_size: odd kernel side length.
        sigma: Gaussian standard deviation.
    Raises:
        AssertionError: if filter_size is even.
    """
    assert (filter_size + 1) % 2 == 0, '2D filter size must be odd number!'
    half_width = (filter_size - 1) // 2
    # Vectorized replacement of the original double loop: same centered
    # integer offsets, same exp(-(i^2+j^2)/(2*sigma^2)) values.
    offsets = np.arange(-half_width, half_width + 1, dtype=np.float32)
    g = np.exp(-(offsets[:, None] ** 2 + offsets[None, :] ** 2) / 2 / sigma / sigma)
    g = g.astype(np.float32)

    # FIX: a 1x1 kernel has min == max; the old min-max normalization divided
    # 0 by 0 and returned NaN. Return the flat all-ones kernel instead.
    value_range = g.max() - g.min()
    if value_range == 0:
        return np.ones_like(g)
    return (g - g.min()) / value_range
+
83
+
84
def get_transform_matrix(ori, speed):
    """Build a (ori*speed, 2, 3) batch of affine translation matrices.

    Each matrix has an identity linear part and a translation proportional to
    orientation * speed, normalized by the sensor size.

    NOTE(review): reads module globals ``ivs_w``, ``ivs_h`` and ``device``
    that are not defined in this file — this function appears superseded by
    ``get_transform_matrix_new`` below, which takes them as parameters;
    confirm before calling.
    """
    ori_num = len(ori)
    speed_num = len(speed)
    transform_matrix = torch.zeros(ori_num * speed_num, 2, 3)
    cnt = 0
    for iOri in range(ori_num):
        for iSpeed in range(speed_num):
            # identity linear part
            transform_matrix[cnt, 0, 0] = 1
            transform_matrix[cnt, 1, 1] = 1

            # translation: ori rows are (y, x); x normalized by width,
            # y by height; negated to shift against the motion direction
            transform_matrix[cnt, 0, 2] = - float(ori[iOri, 1] * speed[iSpeed] / ivs_w)
            transform_matrix[cnt, 1, 2] = - float(ori[iOri, 0] * speed[iSpeed] / ivs_h)

            cnt += 1

    transform_matrix = transform_matrix.to(device)
    return transform_matrix
+
102
+
103
def get_transform_matrix_new(ori, speed, dvs_w, dvs_h, device):
    """Build a (len(ori)*len(speed), 2, 3) batch of affine translation matrices.

    Each matrix has an identity linear part and a translation proportional to
    orientation * speed, normalized by the sensor size (ori rows are (y, x)).
    """
    n_ori = len(ori)
    n_speed = len(speed)
    mats = torch.zeros(n_ori * n_speed, 2, 3)
    for i_ori in range(n_ori):
        for i_speed in range(n_speed):
            # flat index: orientation-major, speed-minor (matches the old
            # running counter)
            k = i_ori * n_speed + i_speed
            mats[k, 0, 0] = 1
            mats[k, 1, 1] = 1
            # negated, size-normalized shift: x by width, y by height
            mats[k, 0, 2] = -float(ori[i_ori, 1] * speed[i_speed] / dvs_w)
            mats[k, 1, 2] = -float(ori[i_ori, 0] * speed[i_speed] / dvs_h)

    return mats.to(device)
+
121
+
122
+ # monitor the inference process
123
def visualize_img(gray_img, tag, curT):
    """Log a single grayscale image tensor to the tensorboard logger.

    NOTE(review): relies on a module-global ``logger`` that is not defined in
    this file — presumably injected elsewhere; confirm before use.
    """
    # FIX: torch tensors have no `.float32()` method (the old call raised
    # AttributeError); `.float()` casts to float32.
    gray_img = gray_img.float()
    img = torch.unsqueeze(gray_img, 0)  # add the channel dim: (1, H, W)
    logger.add_image(tag, img, global_step=curT)
+
128
+
129
def visualize_images(images, tag, curT):
    """Log every slice along the last axis of `images` as a separate image.

    NOTE(review): relies on a module-global ``logger`` that is not defined in
    this file; confirm before use.
    """
    if images.shape[0] < 1:
        return
    stack = torch.squeeze(images)
    slice_count = stack.shape[-1]
    for idx in range(slice_count):
        # (H, W) slice -> (1, H, W) so add_image sees a channel dimension
        frame = torch.unsqueeze(torch.squeeze(stack[:, :, idx]), 0)
        logger.add_image(tag + str(idx), frame, global_step=curT)
+
140
+
141
def visualize_weights(weights, tag, curT):
    """Log each weight row as a min-max-normalized square image.

    Assumes each row of `weights` reshapes to a square stimulus patch.
    NOTE(review): relies on a module-global ``logger`` that is not defined in
    this file; confirm before use.
    """
    if weights.shape[0] < 1:
        return
    w_mat = torch.squeeze(weights)
    n_units = w_mat.shape[0]
    patch_side = int(np.sqrt(w_mat.shape[1]))
    for unit in range(n_units):
        row = torch.squeeze(w_mat[unit, :])
        # min-max normalize to [0, 1] for display
        row = (row - torch.min(row)) / (torch.max(row) - torch.min(row))
        img = torch.unsqueeze(torch.reshape(row, (patch_side, patch_side)), 0)
        logger.add_image(tag + str(unit), img, global_step=curT)
+
156
+
157
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that serializes numpy arrays as (nested) lists."""

    def default(self, obj):
        # only intercept ndarrays; everything else keeps the stock behavior
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
162
+
163
+
164
def vis_trajectory(json_file, filename, **dataDict):
    """Plot saved 3-D trajectories (time vs width vs height) and save a PNG.

    Args:
        json_file: JSON-lines file, one trajectory dict per line
                   (keys: id, x, y, t, color).
        filename: output image path.
        **dataDict: must provide 'spike_h' and 'spike_w'.
    """
    spike_h = dataDict.get('spike_h')
    spike_w = dataDict.get('spike_w')
    traj_dict = []
    with open(json_file, 'r') as f:
        for line in f.readlines():
            traj_dict.append(json.loads(line))

    fig = plt.figure(figsize=[10, 6])
    ax = fig.add_subplot(111, projection='3d')
    # FIX: seed the running min/max with +/-inf so the time axis is correct
    # even when every timestamp exceeds the old hard-coded 1000.
    min_t = np.inf
    max_t = -np.inf

    for tmp_traj in traj_dict:
        tmp_t = np.array(tmp_traj['t'])
        if np.min(tmp_t) < min_t:
            min_t = np.min(tmp_t)
        if np.max(tmp_t) > max_t:
            max_t = np.max(tmp_t)

        tmp_x = spike_w - np.array(tmp_traj['x'])  # mirror width for display
        tmp_y = np.array(tmp_traj['y'])
        tmp_color = np.array(tmp_traj['color']) / 255.
        ax.plot(tmp_t, tmp_x, tmp_y, color=tmp_color, linewidth=2, label='traj ' + str(tmp_traj['id']))

    ax.legend(loc='best', bbox_to_anchor=(0.7, 0., 0.4, 0.8))
    # stretch the time axis relative to the spatial axes
    zoom = [2.2, 0.8, 0.5, 1]
    ax.get_proj = lambda: np.dot(Axes3D.get_proj(ax), np.diag([zoom[0], zoom[1], zoom[2], zoom[3]]))
    ax.set_xlim(min_t, max_t)
    ax.set_ylim(0, spike_w)
    ax.set_zlim(0, spike_h)

    ax.set_xlabel('time', fontsize=15)
    ax.set_ylabel('width', fontsize=15)
    ax.set_zlabel('height', fontsize=15)

    ax.view_init(elev=16, azim=135)
    ax.yaxis.set_major_locator(MultipleLocator(100))
    fig.subplots_adjust(top=1., bottom=0., left=0.2, right=1.)
    # fig.tight_layout()
    # FIX: save before show() — in GUI backends show() can consume/clear the
    # figure, which made the saved PNG blank.
    plt.savefig(filename, dpi=500, transparent=True)
    plt.show()
snnTracker/visualization/get_image.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2023/8/20 16:06
3
+ # @Author : Yajing Zheng
4
+ # @Email: [email protected]
5
+ # @File : get_image.py
6
+ import numpy as np
7
+ import matplotlib
8
+ matplotlib.use('Agg')
9
+ import matplotlib.pyplot as plt
10
+ import json
11
+ from mpl_toolkits.mplot3d import Axes3D
12
+ from matplotlib.pyplot import MultipleLocator
13
+ from matplotlib.patches import Rectangle
14
+ import torch
15
+ import copy
16
+
17
def get_spike_raster(data):
    """Draw a spike raster (eventplot) for *data* of shape (num_neuron, T).

    Returns the current matplotlib figure handle.
    """
    num_neuron, _ = data.shape
    palette = [f'C{idx}' for idx in range(num_neuron)]
    # One horizontal lane per neuron, lanes spaced two units apart so the
    # 1.5-unit-tall event marks slightly overlap their neighbours.
    offsets = np.arange(1, num_neuron * 2 + 1, 2)
    lengths = np.full((num_neuron,), 1.5)

    plt.figure(figsize=(8, 6))
    plt.eventplot(data, colors=palette, lineoffsets=offsets, linelengths=lengths)
    return plt.gcf()
28
+
29
+
30
def get_heatmap_handle(data, marker=None, bounding_box=None):
    # Render *data* (2-D array or tensor) as a blue heatmap, optionally
    # overlaying point markers and axis-aligned bounding boxes; returns the
    # matplotlib figure handle.
    #
    # NOTE(review): marker is indexed as marker[0]=row, marker[1]=col and the
    # row is flipped via h - marker[0], presumably to match image orientation;
    # confirm against the caller's coordinate convention.
    # NOTE(review): bbox appears to be (y0, x0, y1, x1) given the Rectangle
    # construction below -- TODO confirm.

    if torch.is_tensor(data):
        # Detach to CPU NumPy; deepcopy guards against downstream mutation.
        data = copy.deepcopy(data.cpu().detach().numpy())

    fig, ax = plt.subplots(figsize=(8, 6))
    h, w = data.shape
    if marker is not None:
        num_points = marker.shape[1]
        colors = [f'C{i}' for i in range(num_points)]
        for i_point in range(num_points):
            ax.plot(marker[1, i_point], h-marker[0, i_point], 'o', color=colors[i_point], markersize=10)
            ax.annotate('P{}'.format(i_point), (marker[1, i_point], h-marker[0, i_point]))

    if bounding_box is not None:
        for i_box, bbox in enumerate(bounding_box):
            ax.add_patch(Rectangle((bbox[1], bbox[0]), bbox[3]-bbox[1], bbox[2] - bbox[0],
                                   edgecolor='red', facecolor='none', lw=2))

    ax.imshow(data, cmap='Blues', interpolation='nearest')

    # plt.colorbar()
    plt.axis('off')  # optional: hide the axes
    plt.title('Heatmap')

    return plt.gcf()
56
+
57
+
58
def get_histogram_handle(data, marker=None, bounding_box=None):
    """Plot a 20-bin histogram of all values in *data* and return the figure.

    Args:
        data: 2-D array or tensor; flattened before plotting.
        marker, bounding_box: accepted for signature parity with
            get_heatmap_handle but not used here.
    """
    if torch.is_tensor(data):
        data = copy.deepcopy(data.cpu().detach().numpy())

    fig, ax = plt.subplots(figsize=(8, 6))
    # Unused 'h, w = data.shape' unpack removed (it also needlessly rejected
    # non-2-D input).
    ax.hist(data.reshape((-1, 1)), bins=20)

    # Bug fix: title said 'Heatmap' -- copy-paste from get_heatmap_handle.
    plt.title('Histogram')

    return plt.gcf()
72
def vis_trajectory(box_file, json_file, filename, **dataDict):
    """Plot tracked trajectories in 3-D (time, width, height) and show the figure.

    Args:
        box_file: path to a per-frame box result file (read for parity, but its
            contents are currently unused by this function).
        json_file: JSON-lines file; each line is one trajectory dict with keys
            't', 'x', 'y', 'color' (0-255 RGB) and 'id'.
        filename: intended output path (the savefig calls are disabled below).
        **dataDict: must provide 'spike_h' and 'spike_w' (sensor resolution).
    """
    spike_h = dataDict.get('spike_h')
    spike_w = dataDict.get('spike_w')

    trajectories = []
    with open(json_file, 'r') as f:
        for line in f:
            trajectories.append(json.loads(line))

    # Bug fix: the box file was opened but never closed (leaked handle);
    # read it inside a context manager instead.
    with open(box_file, 'r') as bf:
        result_lines = bf.readlines()  # kept for parity; not used below

    fig = plt.figure(figsize=[10, 6])
    ax = fig.add_subplot(111, projection='3d')
    min_t = 1000
    max_t = 0

    for traj in trajectories:
        t = np.array(traj['t'])
        min_t = min(min_t, np.min(t))
        max_t = max(max_t, np.max(t))

        # x is mirrored so the plot matches the sensor orientation.
        x = spike_w - np.array(traj['x'])
        y = np.array(traj['y'])
        color = np.array(traj['color']) / 255.
        ax.plot(t, x, y, color=color, linewidth=2, label='traj ' + str(traj['id']))

    ax.legend(loc='best', bbox_to_anchor=(0.7, 0., 0.4, 0.8))
    zoom = [2.2, 0.8, 0.5, 1]
    ax.get_proj = lambda: np.dot(Axes3D.get_proj(ax), np.diag(zoom))
    ax.set_xlim(min_t, max_t)
    ax.set_ylim(0, spike_w)
    ax.set_zlim(0, spike_h)

    ax.set_xlabel('time', fontsize=15)
    ax.set_ylabel('width', fontsize=15)
    ax.set_zlabel('height', fontsize=15)

    ax.view_init(elev=16, azim=135)
    # ax.view_init(elev=2, azim=27)
    ax.yaxis.set_major_locator(MultipleLocator(100))
    fig.subplots_adjust(top=1., bottom=0., left=0.2, right=1.)
    # plt.savefig(filename, dpi=500, transparent=True)
    # filename = filename.replace('png', 'eps')
    # plt.savefig(filename, dpi=500, transparent=True)
    plt.show()
snnTracker/visualization/get_video.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2022/6/12 15:21
3
+ # @Author : Yajing Zheng
4
+ # @File : visualize.py
5
+ import cv2
6
+ import numpy as np
7
+ import matplotlib
8
+ matplotlib.use('Agg')
9
+ import matplotlib.pyplot as plt
10
+
11
+ import matplotlib.cm as cm
12
+ import matplotlib.animation as animation
13
+
14
def obtain_spike_video(spikes, video_filename, **dataDict):
    """Write binary spike frames (T, H, W) to an MJPG video at 30 fps.

    dataDict must provide 'spike_h' and 'spike_w' (frame size).
    """
    height = dataDict.get('spike_h')
    width = dataDict.get('spike_w')

    writer = cv2.VideoWriter(video_filename, cv2.VideoWriter_fourcc(*'MJPG'), 30, (width, height))

    for frame_idx in range(spikes.shape[0]):
        # Scale {0,1} spikes to {0,255} and expand to 3-channel BGR.
        frame = (spikes[frame_idx, :, :] * 255).astype(np.uint8)
        writer.write(cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR))

    writer.release()
27
+
28
+
29
def obtain_reconstruction_video(images, video_filename, **dataDict):
    """Write reconstructed grayscale frames (num, H, W) to an MJPG video at 30 fps.

    Args:
        images: array of frames with values expected in 0-255.
        video_filename: output .avi path.
        **dataDict: must provide 'spike_h' and 'spike_w' (frame size).
    """
    spike_h = dataDict.get('spike_h')
    spike_w = dataDict.get('spike_w')

    img_num = images.shape[0]
    mov = cv2.VideoWriter(video_filename, cv2.VideoWriter_fourcc(*'MJPG'), 30, (spike_w, spike_h))
    for iImg in range(img_num):
        tmp_img = images[iImg, :, :]
        # Bug fix: VideoWriter.write needs uint8 BGR frames. The sibling
        # obtain_spike_video casts explicitly; this one did not, so float
        # reconstructions made cvtColor/write fail. No-op for uint8 input.
        if tmp_img.dtype != np.uint8:
            tmp_img = tmp_img.astype(np.uint8)
        tmp_img = cv2.cvtColor(tmp_img, cv2.COLOR_GRAY2BGR)
        mov.write(tmp_img)

    mov.release()
41
+
42
+
43
def obtain_mot_video(spikes, video_filename, res_filepath, **dataDict):
    # Overlay multi-object-tracking results (and optional ground-truth boxes)
    # on binary spike frames and write them to an MJPG video at 30 fps.
    #
    # spikes: (T, H, W) binary array. res_filepath: text file with lines
    # 'frame,track_id,x,y,w,h,...'. dataDict must provide 'spike_h'/'spike_w'
    # and may provide 'labeled_data_dir' (ground-truth file, same format).
    spike_h = dataDict.get('spike_h')
    spike_w = dataDict.get('spike_w')

    gt_file = dataDict.get('labeled_data_dir')
    gt_boxes = {}
    if gt_file is not None:
        gt_f = open(gt_file, 'r')
        gt_lines = gt_f.readlines()
        for line in gt_lines:
            gt_term = line.split(',')
            time_step = gt_term[0]
            box_id = gt_term[1]
            x = float(gt_term[2])
            y = float(gt_term[3])
            w = float(gt_term[4])
            h = float(gt_term[5])

            # Group ground-truth boxes by frame index (string key).
            if str(time_step) not in gt_boxes:
                gt_boxes[str(time_step)] = []
            bbox = [box_id, x, y, w, h]
            gt_boxes[str(time_step)].append(bbox)

        gt_f.close()

    result_file = res_filepath
    test_boxes = {}
    result_f = open(result_file, 'r')
    result_lines = result_f.readlines()
    color_dict = {}  # one random BGR color per track id, stable across frames

    for line in result_lines:
        res_box = line.split(',')
        time_step = res_box[0]
        track_id = res_box[1]
        if track_id not in color_dict.keys():
            colors = (np.random.rand(1, 3) * 255).astype(np.uint8)
            color_dict[track_id] = np.squeeze(colors)

        x = float(res_box[2])
        y = float(res_box[3])
        w = float(res_box[4])
        h = float(res_box[5])

        # Group predicted boxes by frame index (string key).
        if str(time_step) not in test_boxes:
            test_boxes[str(time_step)] = []

        test_box = [track_id, x, y, w, h]
        test_boxes[str(time_step)].append(test_box)

    result_f.close()

    mov = cv2.VideoWriter(video_filename, cv2.VideoWriter_fourcc(*'MJPG'), 30, (spike_w, spike_h))

    timestamps = spikes.shape[0]
    # NOTE(review): rendering starts at frame 151 (earlier frames skipped) --
    # presumably a tracker warm-up period; confirm with the caller.
    for t in range(151, timestamps):
        # for t in range(160, 1000):
        tmp_ivs = spikes[t, :, :] * 255
        tmp_ivs = cv2.cvtColor(tmp_ivs.astype(np.uint8), cv2.COLOR_GRAY2BGR)

        if len(gt_boxes) > 0:
            if str(t) in gt_boxes:
                gts = gt_boxes[str(t)]
                gt_num = len(gts)
                for i in range(gt_num):
                    box = gts[i]
                    box_id = box[0]
                    # Ground truth drawn in white. Box layout is
                    # [id, x, y, w, h] but indices 1/2 are swapped when
                    # drawing -- assumes x is the row coordinate; TODO confirm.
                    cv2.rectangle(tmp_ivs, (int(box[2]), int(box[1])),
                                  (int(box[2] + box[4]), int(box[1] + box[3])),
                                  (int(255), int(255), int(255)), 2)

        if str(t) in test_boxes:
            test = test_boxes[str(t)]
            test_num = len(test)
            for i in range(test_num):
                box = test[i]
                box_id = box[0]
                colors = color_dict[box_id]
                # Tracked boxes drawn in the track's assigned random color.
                cv2.rectangle(tmp_ivs, (int(box[2]), int(box[1])),
                              (int(box[2] + box[4]), int(box[1] + box[3])),
                              (int(colors[0]), int(colors[1]), int(colors[2])), 2)

        mov.write(tmp_ivs)

    mov.release()
128
+
129
+
130
def obtain_detection_video(spikes, video_filename, res_filepath, evaluate_seq_len, begin_idx=0, **dataDict):
    """Overlay detection results (and ground-truth boxes, when available) on
    spike frames and write one annotated frame per evaluated sequence.

    Args:
        spikes: (T, H, W) binary spike array.
        video_filename: output .avi path.
        res_filepath: detection results, lines 'frame,track_id,x,y,w,h,...'.
        evaluate_seq_len: number of labeled sequences to render.
        begin_idx: index of the first labeled sequence.
        **dataDict: must provide 'spike_h' and 'spike_w'; may provide
            'labeled_data_dir' (list of per-sequence ground-truth files with
            lines 'x,y,w,h').
    """
    spike_h = dataDict.get('spike_h')
    spike_w = dataDict.get('spike_w')

    # Bug fix: start_idx/end_idx were only assigned inside the
    # 'gt_file is not None' branch, so calling without labeled data raised
    # NameError in the rendering loop below. Define them unconditionally.
    start_idx = begin_idx
    end_idx = begin_idx + evaluate_seq_len

    gt_file = dataDict.get('labeled_data_dir')
    gt_boxes = {}
    if gt_file is not None:
        for seq_no in range(start_idx, end_idx):
            gt_filename = gt_file[seq_no]
            # Context manager also guarantees the handle is closed on error.
            with open(gt_filename, 'r') as gt_f:
                gt_lines = gt_f.readlines()
            for line in gt_lines:
                tmp_box = line.split(',')

                x = float(tmp_box[0])
                y = float(tmp_box[1])
                w = float(tmp_box[2])
                h = float(tmp_box[3])
                box_id = int(0)  # GT files carry no track id

                # Group ground-truth boxes by sequence index (string key).
                if str(seq_no) not in gt_boxes:
                    gt_boxes[str(seq_no)] = []
                bbox = [box_id, x, y, w, h]
                gt_boxes[str(seq_no)].append(bbox)

    result_file = res_filepath
    test_boxes = {}
    color_dict = {}  # one random BGR color per track id
    with open(result_file, 'r') as result_f:
        result_lines = result_f.readlines()

    for line in result_lines:
        res_box = line.split(',')
        time_step = res_box[0]
        track_id = res_box[1]
        if track_id not in color_dict.keys():
            colors = (np.random.rand(1, 3) * 255).astype(np.uint8)
            color_dict[track_id] = np.squeeze(colors)

        x = float(res_box[2])
        y = float(res_box[3])
        w = float(res_box[4])
        h = float(res_box[5])

        # Group predicted boxes by frame index (string key).
        if str(time_step) not in test_boxes:
            test_boxes[str(time_step)] = []
        test_boxes[str(time_step)].append([track_id, x, y, w, h])

    mov = cv2.VideoWriter(video_filename, cv2.VideoWriter_fourcc(*'MJPG'), 30, (spike_w, spike_h))

    # One representative frame (mid-interval) is rendered per sequence.
    # gt_intv = int(spikes.shape[0] / evaluate_seq_len)
    gt_intv = 400

    for i_gt in range(start_idx + 1, end_idx):
        t = i_gt * gt_intv + int(gt_intv / 2)
        tmp_ivs = spikes[t, :, :] * 255
        tmp_ivs = cv2.cvtColor(tmp_ivs.astype(np.uint8), cv2.COLOR_GRAY2BGR)

        if len(gt_boxes) > 0:
            gts = gt_boxes[str(i_gt)]
            for box in gts:
                # Ground truth drawn in white; x is mirrored across the
                # width -- assumes GT x is stored in flipped orientation;
                # TODO confirm against the labeling tool.
                cv2.rectangle(tmp_ivs, (int(spike_w - box[1]), int(box[2])),
                              (int(spike_w - box[1] - box[3]), int(box[2] + box[4])),
                              (int(255), int(255), int(255)), 2)

        if str(t) in test_boxes:
            for box in test_boxes[str(t)]:
                colors = color_dict[box[0]]
                cv2.rectangle(tmp_ivs, (int(box[2]), int(box[1])),
                              (int(box[2] + box[4]), int(box[1] + box[3])),
                              (int(colors[0]), int(colors[1]), int(colors[2])), 2)

        mov.write(tmp_ivs)

    mov.release()
222
+
223
def get_heatVideo(results, video_filename):
    """Render a sequence of 2-D maps as an animated heat-map video via ffmpeg.

    results: sequence convertible to an array of shape (num_frames, H, W).
    video_filename: output path handed to matplotlib's FFMpegWriter.
    """
    results = np.array(results)
    artist_frames = []

    fig = plt.figure()
    for heat in results:
        # frames.append([plt.imshow(heat, cmap=cm.Greys_r, animated=True)])
        artist_frames.append([plt.imshow(heat, cmap=cm.Blues, animated=True)])

    ani = animation.ArtistAnimation(fig, artist_frames, interval=50, blit=True,
                                    repeat_delay=1000)

    # change the path to where you save ffmpeg
    plt.rcParams['animation.ffmpeg_path'] = 'F:\\ffmpeg-N-99818-g993429cfb4-win64-gpl-shared-vulkan\\bin\\ffmpeg.exe'
    FFwrite = animation.FFMpegWriter(fps=30, extra_args=['-vcodec', 'libx264'])
    ani.save(video_filename, writer=FFwrite)
    plt.show()
242
+
243
+
244
+
245
+
snnTracker/visualization/optical_flow_visualization.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2022/7/21
3
+ # @Author : Rui Zhao
4
+ # @File : optical_flow_visualization.py
5
+
6
+ import torch
7
+ import numpy as np
8
+ import cv2
9
+ import math
10
+
11
+
12
+ #################### Interface ####################
13
+ def flow_visualization(flow, mode='normal', use_cv2=True):
14
+ if mode == 'normal':
15
+ flow_vis = flow_to_image(flow_uv=flow, convert_to_bgr=use_cv2)
16
+ elif mode == 'scflow':
17
+ flow_vis = flow_to_img_scflow(flow_uv=flow)
18
+ if not use_cv2:
19
+ flow_vis = cv2.cvtColor(flow_vis, cv2.COLOR_BGR2RGB)
20
+ elif mode == 'evflow':
21
+ flow_vis = flow_viz_np(flow_x=flow[:,:,0], flow_y=flow[:,:,1])
22
+
23
+ return flow_vis
24
+
25
+
26
def vis_color_map(use_cv2=True):
    """Build 200x200 color-map reference images for all three flow renderers.

    Returns [normal, scflow, evflow] visualizations of a uniform flow grid.
    """
    axis = np.linspace(-100, 99, 200)
    grid_x, grid_y = np.meshgrid(axis, axis)
    flow = np.concatenate((grid_x[:, :, None], grid_y[:, :, None]), axis=2)
    return [flow_visualization(flow=flow, mode=name, use_cv2=use_cv2)
            for name in ('normal', 'scflow', 'evflow')]
35
+
36
+
37
def make_colorwheel():
    """
    Build the 55-color Middlebury optical-flow color wheel.

    Baker et al., "A Database and Evaluation Methodology for Optical Flow"
    (ICCV 2007), http://vision.middlebury.edu/flow/flowEval-iccv07.pdf.
    Follows Daniel Scharstein's C++ and Deqing Sun's Matlab source code.

    Returns:
        np.ndarray: color wheel of shape (55, 3), RGB values in [0, 255].
    """
    # Each hue arc: (length, channel held at 255, ramped channel, ramp falls?)
    segments = (
        (15, 0, 1, False),  # R -> Y: green ramps up
        (6, 1, 0, True),    # Y -> G: red ramps down
        (4, 1, 2, False),   # G -> C: blue ramps up
        (11, 2, 1, True),   # C -> B: green ramps down
        (13, 2, 0, False),  # B -> M: red ramps up
        (6, 0, 2, True),    # M -> R: blue ramps down
    )
    total = sum(seg[0] for seg in segments)
    wheel = np.zeros((total, 3))

    row = 0
    for length, held, ramped, falling in segments:
        ramp = np.floor(255 * np.arange(0, length) / length)
        wheel[row:row + length, held] = 255
        wheel[row:row + length, ramped] = 255 - ramp if falling else ramp
        row += length
    return wheel
85
+
86
+
87
+ #################### Normal Version ####################
88
+ """
89
+ From https://github.com/princeton-vl/RAFT/blob/master/core/utils/flow_viz.py
90
+ """
91
def flow_uv_to_colors(u, v, convert_to_bgr=False):
    """
    Map normalized flow components to colors via the Middlebury color wheel.

    According to the C++ source code of Daniel Scharstein and the Matlab
    source code of Deqing Sun.

    Args:
        u (np.ndarray): horizontal flow of shape [H, W], radius ideally <= 1.
        v (np.ndarray): vertical flow of shape [H, W].
        convert_to_bgr (bool, optional): emit BGR channel order. Defaults to False.

    Returns:
        np.ndarray: flow visualization image of shape [H, W, 3], uint8.
    """
    wheel = make_colorwheel()  # shape (55, 3)
    n_hues = wheel.shape[0]
    out = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)

    radius = np.sqrt(np.square(u) + np.square(v))
    angle = np.arctan2(-v, -u) / np.pi          # direction in [-1, 1]
    wheel_pos = (angle + 1) / 2 * (n_hues - 1)  # fractional wheel index
    lo = np.floor(wheel_pos).astype(np.int32)
    hi = lo + 1
    hi[hi == n_hues] = 0                        # wrap around the wheel
    frac = wheel_pos - lo

    for ch in range(wheel.shape[1]):
        hue_vals = wheel[:, ch]
        blended = (1 - frac) * (hue_vals[lo] / 255.0) + frac * (hue_vals[hi] / 255.0)
        inside = radius <= 1
        # Saturate toward white as magnitude drops inside the unit circle.
        blended[inside] = 1 - radius[inside] * (1 - blended[inside])
        blended[~inside] *= 0.75                # dim out-of-range flow
        # 2-ch flips channel order from RGB to BGR.
        target = 2 - ch if convert_to_bgr else ch
        out[:, :, target] = np.floor(255 * blended)
    return out
125
+
126
+
127
def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
    """
    Visualize a flow field [H, W, 2] as a color image [H, W, 3].

    Args:
        flow_uv (np.ndarray): flow UV image of shape [H, W, 2].
        clip_flow (float, optional): clip maximum of flow values. Defaults to None.
        convert_to_bgr (bool, optional): emit BGR channel order. Defaults to False.

    Returns:
        np.ndarray: flow visualization image of shape [H, W, 3].
    """
    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
    if clip_flow is not None:
        flow_uv = np.clip(flow_uv, 0, clip_flow)

    u, v = flow_uv[:, :, 0], flow_uv[:, :, 1]
    # Normalize by the maximum magnitude so the wheel radius is <= 1.
    magnitude = np.sqrt(np.square(u) + np.square(v))
    scale = np.max(magnitude) + 1e-5
    return flow_uv_to_colors(u / scale, v / scale, convert_to_bgr)
149
+
150
+
151
+ #################### SCFlow Version ####################
152
def flow_uv_to_colors_scflow(u, v, convert_to_bgr=False):
    """
    SCFlow variant of the color-wheel mapping.

    Identical to flow_uv_to_colors except the angle uses arctan2(-v, u),
    i.e. the horizontal component is not negated.

    Args:
        u (np.ndarray): horizontal flow of shape [H, W].
        v (np.ndarray): vertical flow of shape [H, W].
        convert_to_bgr (bool, optional): emit BGR channel order. Defaults to False.

    Returns:
        np.ndarray: flow visualization image of shape [H, W, 3], uint8.
    """
    img = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
    wheel = make_colorwheel()  # shape (55, 3)
    n_hues = wheel.shape[0]

    mag = np.sqrt(np.square(u) + np.square(v))
    theta = np.arctan2(-v, u) / np.pi       # sign of u differs from 'normal'
    pos = (theta + 1) / 2 * (n_hues - 1)
    idx0 = np.floor(pos).astype(np.int32)
    idx1 = idx0 + 1
    idx1[idx1 == n_hues] = 0                # wrap around the wheel
    w_frac = pos - idx0

    for chan in range(wheel.shape[1]):
        hue_vals = wheel[:, chan]
        mixed = (1 - w_frac) * (hue_vals[idx0] / 255.0) + w_frac * (hue_vals[idx1] / 255.0)
        in_range = mag <= 1
        mixed[in_range] = 1 - mag[in_range] * (1 - mixed[in_range])
        mixed[~in_range] *= 0.75            # dim out-of-range flow
        dest = 2 - chan if convert_to_bgr else chan  # RGB -> BGR flip
        img[:, :, dest] = np.floor(255 * mixed)
    return img
189
+
190
+
191
def flow_to_img_scflow(flow_uv, clip_flow=None):
    """
    SCFlow-style visualization of a flow field [H, W, 2] (RGB channel order).

    Args:
        flow_uv (np.ndarray): flow UV image of shape [H, W, 2].
        clip_flow (float, optional): clip maximum of flow values. Defaults to None.

    Returns:
        np.ndarray: flow visualization image of shape [H, W, 3].
    """
    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
    if clip_flow is not None:
        flow_uv = np.clip(flow_uv, 0, clip_flow)

    u, v = flow_uv[:, :, 0], flow_uv[:, :, 1]
    # Normalize by the maximum magnitude so the wheel radius is <= 1.
    scale = np.max(np.sqrt(np.square(u) + np.square(v))) + 1e-5
    return flow_uv_to_colors_scflow(u / scale, v / scale, False)
215
+
216
+
217
+ #################### EVFlow Version ####################
218
+ """
219
+ From https://github.com/chan8972/Spike-FlowNet/blob/master/vis_utils.py
220
+ """
221
+
222
+ """
223
+ Generates an RGB image where each point corresponds to flow in that direction from the center,
224
+ as visualized by flow_viz_tf.
225
+ Output: color_wheel_rgb: [1, width, height, 3]
226
+ """
227
def draw_color_wheel_np(width, height):
    """Render the flow color wheel as an RGB image of the given size.

    Each pixel is colored as if the flow pointed from the image center to it.
    """
    xs = np.linspace(-width / 2., width / 2., width)
    ys = np.linspace(-height / 2., height / 2., height)
    grid_x, grid_y = np.meshgrid(xs, ys)
    return flow_viz_np(grid_x, grid_y)
233
+
234
+
235
+ """
236
+ Visualizes optical flow in HSV space using TensorFlow, with orientation as H, magnitude as V.
237
+ Returned as RGB.
238
+ Input: flow: [batch_size, width, height, 2]
239
+ Output: flow_rgb: [batch_size, width, height, 3]
240
+ """
241
def flow_viz_np(flow_x, flow_y):
    """HSV flow visualization: hue encodes direction, value encodes magnitude.

    flow_x, flow_y: [H, W] flow components. Returns an [H, W, 3] BGR image.
    """
    import cv2
    stacked = np.stack((flow_x, flow_y), axis=2)
    magnitude = np.linalg.norm(stacked, axis=2)

    # Map angle from [-pi, pi] into OpenCV's 0-180 hue range.
    hue = (np.arctan2(flow_y, flow_x) + np.pi) * (180. / np.pi / 2.)
    hsv = np.zeros([flow_x.shape[0], flow_x.shape[1], 3], dtype=np.uint8)
    hsv[:, :, 0] = hue.astype(np.uint8)
    hsv[:, :, 1] = 255
    hsv[:, :, 2] = cv2.normalize(magnitude, None, 0, 255, cv2.NORM_MINMAX)
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
256
+
257
+
258
+
259
+ #################### Visualization tools when training SCFlow ####################
260
def outflow_img(flow_list, vis_path, name_prefix='flow', max_batch=4):
    # Save SCFlow-style visualizations of the first pyramid level's flow for
    # the first few samples of the batch as PNG files.
    #
    # flow_list: list of flow tensors [B, 2, H, W]; only flow_list[0] is used.
    # vis_path: output directory for '<prefix>_batch_id=XX.png' files.
    # NOTE(review): the guard 'batch > max_batch' lets max_batch + 1 samples
    # through (indices 0..max_batch); '>=' may have been intended -- confirm.
    flow = flow_list[0]
    batch_size, c, h, w = flow.shape

    for batch in range(batch_size):
        if batch > max_batch:
            break
        # [2, H, W] tensor -> [H, W, 2] NumPy array on CPU.
        flow_current = flow[batch,:,:,:].permute(1,2,0).detach().cpu().numpy()
        flow_img = flow_visualization(flow_current, mode='scflow', use_cv2=True)

        cv2.imwrite(vis_path + '/{:s}_batch_id={:02d}.png'.format(name_prefix, batch), flow_img)

    return