update utils generate_embeddings_io

utils/utils.py  CHANGED  (+636 -636)

@@ -1,636 +1,636 @@
Every line of the file is marked removed and re-added, but the only line whose content actually changes is line 227 in generate_embeddings_stream_io, where a previously truncated to_csv call is completed:

-            df.to_csv(f'embeddings_downsample_{downsample_rate}_{
+            df.to_csv(f'embeddings_downsample_{downsample_rate}_{N}_frames.csv', index=False)

The updated utils/utils.py in full:
import os
import io
import pickle
import copy
from collections import Counter
from pathlib import Path
from tempfile import NamedTemporaryFile
import regex as re
import numpy as np
import pandas as pd
from sklearn.manifold import TSNE
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import torch
from tqdm import tqdm
from PIL import Image
from transformers import AutoProcessor, AutoModel
import streamlit as st
from .data_loading import load_multiple_annotations, load_multiple_annotations_io
from .data_processing import generate_label_array
from .seqIo import seqIo_reader
from .mp4Io import mp4Io_reader

SLIP_MODEL_ID = "google/siglip-so400m-patch14-384"
CLIP_MODEL_ID = "openai/clip-vit-base-patch32"

def create_annot_fname_dict(annot_fnames: list[str]) -> dict:
    fs = re.compile(r'.*(_\d+)$')

    unique_files = set()
    for file in annot_fnames:
        file_name = os.fsdecode(file)
        base_name, _ = os.path.splitext(file_name)
        if fs.match(base_name):
            ind = len(fs.match(base_name).group(1))
            unique_files.add(base_name[:-ind])
        else:
            unique_files.add(base_name)

    annot_fname_dict = {}
    for unique_file in unique_files:
        annot_fname_dict.update({unique_file: [file for file in annot_fnames if unique_file in file]})
    return annot_fname_dict

def create_annot_fname_dict_io(annot_fnames: list[str], annot_files: list) -> dict:
    annot_file_dict = {}
    for file in annot_files:
        annot_file_dict.update({file.name: file})
    fs = re.compile(r'.*(_\d+)$')

    unique_files = set()
    for file in annot_fnames:
        file_name = os.fsdecode(file)
        base_name, _ = os.path.splitext(file_name)
        if fs.match(base_name):
            ind = len(fs.match(base_name).group(1))
            unique_files.add(base_name[:-ind])
        else:
            unique_files.add(base_name)

    annot_fname_dict = {}
    for unique_file in unique_files:
        annot_list = [file for file in annot_fnames if unique_file in file]
        annot_list.sort()
        annot_file_list = [annot_file_dict[annot_file_name] for annot_file_name in annot_list]
        annot_fname_dict.update({unique_file: annot_file_list})
    return annot_fname_dict

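# Usage sketch (hypothetical filenames): numbered parts of one recording are
# grouped under their common stem.
#   create_annot_fname_dict(['clip_1.annot', 'clip_2.annot', 'intro.annot'])
#   -> {'clip': ['clip_1.annot', 'clip_2.annot'], 'intro': ['intro.annot']}
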
def get_io_reader(uploaded_file):
    # NOTE: redefined below with .mp4 support; in Python the later definition wins.
    assert uploaded_file.name[-3:] == 'seq', 'Not a seq file'
    with NamedTemporaryFile(suffix="seq", delete=False) as temp:
        temp.write(uploaded_file.getvalue())
    sr = seqIo_reader(temp.name)
    return sr

def load_slip_model(device):
    return AutoModel.from_pretrained(SLIP_MODEL_ID).to(device)

def load_slip_preprocessor():
    return AutoProcessor.from_pretrained(SLIP_MODEL_ID)

def load_clip_model(device):
    return AutoModel.from_pretrained(CLIP_MODEL_ID).to(device)

def load_clip_preprocessor():
    return AutoProcessor.from_pretrained(CLIP_MODEL_ID)

def encode_image(image, device, model, processor):
    with torch.no_grad():
        #convert_models_to_fp32(model)
        inputs = processor(images=image, return_tensors="pt").to(device)
        image_features = model.get_image_features(**inputs)
    return image_features.cpu().numpy().flatten()

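# Usage sketch (assumes a local image file 'frame.png'):
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model, processor = load_clip_model(device), load_clip_preprocessor()
#   vec = encode_image(Image.open('frame.png'), device, model, processor)  # 1-D np.ndarray
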
def generate_embeddings_stream(fnames: list[str],
                               model='SLIP',
                               downsample_rate=4,
                               save_csv=False) -> tuple[list, list]:
    # set up model and device (only 'SLIP' and 'CLIP' are supported)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    if model == 'SLIP':
        embed_model = load_slip_model(device)
        processor = load_slip_preprocessor()
    elif model == 'CLIP':
        embed_model = load_clip_model(device)
        processor = load_clip_preprocessor()

    all_video_embeddings = []
    all_video_frames = []
    for fname in fnames:
        # read in file
        is_seq = False
        if fname[-3:] == 'seq': is_seq = True

        if is_seq:
            sr = seqIo_reader(fname)
        else:
            sr = mp4Io_reader(fname)
        N = sr.header['numFrames']

        # set up embeddings and frame arrays
        embeddings = []
        frames = list(range(N))[::downsample_rate]
        print(frames)

        # create progress bar
        i = 0
        pbar_text = lambda i: f'Creating embeddings for {fname}. {i}/{len(frames)} frames.'
        pbar = st.progress(0, text=pbar_text(0))

        # convert each frame to embeddings
        for f in tqdm(frames):
            img, _ = sr.getFrame(f)
            img_arr = np.array(img)
            if is_seq:
                img_rgb = Image.fromarray(img_arr, 'L').convert('RGB')
            else:
                img_rgb = Image.fromarray(img_arr).convert('RGB')

            embeddings.append(encode_image(img_rgb, device, embed_model, processor))

            # update progress bar
            i += 1
            pbar.progress(i/len(frames), pbar_text(i))

        # save csv of single file
        if save_csv:
            df = pd.DataFrame(embeddings)
            df['Frame'] = frames

            # save csv
            basename = Path(fname).stem
            df.to_csv(f'{basename}_embeddings_downsample_{downsample_rate}.csv', index=False)

        all_video_embeddings.append(np.array(embeddings))
        all_video_frames.append(frames)
    return all_video_embeddings, all_video_frames

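# Frame selection sketch: with N = 10 frames and downsample_rate = 4,
#   list(range(10))[::4] -> [0, 4, 8]
# so roughly every fourth frame is embedded.
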
def get_io_reader(uploaded_file):
    if uploaded_file.name[-3:] == 'seq':
        with NamedTemporaryFile(suffix="seq", delete=False) as temp:
            temp.write(uploaded_file.getvalue())
        sr = seqIo_reader(temp.name)
    else:
        with NamedTemporaryFile(suffix="mp4", delete=False) as temp:
            temp.write(uploaded_file.getvalue())
        sr = mp4Io_reader(temp.name)
    return sr

def generate_embeddings_stream_io(uploaded_files: list,
                                  model='SLIP',
                                  downsample_rate=4,
                                  save_csv=False) -> tuple[list, list]:
    # set up model and device (only 'SLIP' and 'CLIP' are supported)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    if model == 'SLIP':
        embed_model = load_slip_model(device)
        processor = load_slip_preprocessor()
    elif model == 'CLIP':
        embed_model = load_clip_model(device)
        processor = load_clip_preprocessor()

    all_video_embeddings = []
    all_video_frames = []
    for file in uploaded_files:
        is_seq = False
        if file.name[-3:] == 'seq': is_seq = True

        # read in file
        sr = get_io_reader(file)
        N = sr.header['numFrames']

        # set up embeddings and frame arrays
        embeddings = []
        frames = list(range(N))[::downsample_rate]
        print(frames)

        # create progress bar
        i = 0
        pbar_text = lambda i: f'Creating embeddings for {file.name}. {i}/{len(frames)} frames.'
        pbar = st.progress(0, text=pbar_text(0))

        # convert each frame to embeddings
        for f in tqdm(frames):
            img, _ = sr.getFrame(f)
            img_arr = np.array(img)
            if is_seq:
                img_rgb = Image.fromarray(img_arr, 'L').convert('RGB')
            else:
                img_rgb = Image.fromarray(img_arr).convert('RGB')

            embeddings.append(encode_image(img_rgb, device, embed_model, processor))

            # update progress bar
            i += 1
            pbar.progress(i/len(frames), pbar_text(i))

        # save csv of single file
        if save_csv:
            df = pd.DataFrame(embeddings)
            df['Frame'] = frames

            # save csv
            df.to_csv(f'embeddings_downsample_{downsample_rate}_{N}_frames.csv', index=False)

        all_video_embeddings.append(np.array(embeddings))
        all_video_frames.append(frames)
    return all_video_embeddings, all_video_frames

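# With downsample_rate=4 and a 1000-frame upload, the fixed to_csv line above
# writes 'embeddings_downsample_4_1000_frames.csv'.
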
def create_embeddings_csv(out: str,
                          fnames: list[str],
                          embeddings: list[np.ndarray],
                          frames: list[list[int]],
                          annotations: list[list[str]],
                          test_fnames: None | list[str],
                          views: None | list[str],
                          conditions: None | list[str],
                          downsample_rate=4,
                          filesystem=None):
    """
    Builds a DataFrame containing all of the generated embeddings and provided
    information, intended to be written out as a .csv file.

    Parameters:
    -----------
    out : str
        The name of the resulting file.
    fnames : list[str]
        Video sources for each of the embedding arrays.
    embeddings : list[np.ndarray]
        The generated embeddings from the images.
    downsample_rate : int
        The downsample_rate used for generating the embeddings.
    """
    assert len(fnames) == len(embeddings)
    assert len(embeddings) == len(frames)
    all_embeddings = np.vstack(embeddings)
    df = pd.DataFrame(all_embeddings)

    labels = []
    for i, annot_fnames in enumerate(annotations):
        _, ext = os.path.splitext(annot_fnames[0])
        if ext == '.annot':
            annot, _, _, sr = load_multiple_annotations(annot_fnames, filesystem=filesystem)
            annot_labels = generate_label_array(annot, downsample_rate, len(frames[i]))
        elif ext == '.csv':
            if not filesystem:
                annot_df = pd.read_csv(annot_fnames[0], header=None)
            else:
                with filesystem.open(annot_fnames[0], 'r') as csv_file:
                    annot_df = pd.read_csv(csv_file, header=None)
            annot_labels = annot_df[0].to_list()[::downsample_rate]
            assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure that the passed in csv file has no header."
        else:
            raise ValueError(f'Incompatible file for annotations used. Got a file of type "{ext}".')
        assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure you have passed in the correct files."
        print(annot_labels)
        labels.append(annot_labels)
    all_labels = np.hstack(labels)
    print(len(all_labels))
    df['Label'] = all_labels

    all_frames = np.hstack(frames)
    df['Frame'] = all_frames
    sources = [[fname for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)]
    all_sources = np.hstack(sources)
    df['Source'] = all_sources

    if test_fnames:
        t_split = lambda x: x in test_fnames
        test = [[t_split(fname) for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)]
    else:
        test = [[True for _ in range(len(frames[i]))] for i, _ in enumerate(fnames)]
    all_test = np.hstack(test)
    df['Test'] = all_test

    if views:
        view = [[views[i] for _ in range(len(frames[i]))] for i in range(len(fnames))]
    else:
        view = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))]
    all_view = np.hstack(view)
    df['View'] = all_view

    if conditions:
        condition = [[conditions[i] for _ in range(len(frames[i]))] for i in range(len(fnames))]
    else:
        condition = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))]
    all_condition = np.hstack(condition)
    df['Condition'] = all_condition
    return df

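# Usage sketch (hypothetical names; the labels CSV holds one per-frame label per
# row, no header, at the original frame rate):
#   emb = [np.random.rand(3, 512)]                        # one video, 3 kept frames
#   pd.Series(['walk'] * 8 + ['rest'] * 4).to_csv('demo_labels.csv',
#                                                 index=False, header=False)
#   df = create_embeddings_csv(out='demo.csv', fnames=['demo.mp4'], embeddings=emb,
#                              frames=[[0, 4, 8]], annotations=[['demo_labels.csv']],
#                              test_fnames=None, views=['top'], conditions=['fed'],
#                              downsample_rate=4)
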
def create_embeddings_csv_io(out: str,
                             fnames: list[str],
                             embeddings: list[np.ndarray],
                             frames: list[list[int]],
                             annotations: list,
                             test_fnames: None | list[str],
                             views: None | list[str],
                             conditions: None | list[str],
                             downsample_rate=4):
    """
    Builds a DataFrame containing all of the generated embeddings and provided
    information from uploaded (in-memory) annotation files, intended to be
    written out as a .csv file.

    Parameters:
    -----------
    out : str
        The name of the resulting file.
    fnames : list[str]
        Video sources for each of the embedding arrays.
    embeddings : list[np.ndarray]
        The generated embeddings from the images.
    downsample_rate : int
        The downsample_rate used for generating the embeddings.
    """
    assert len(fnames) == len(embeddings)
    assert len(embeddings) == len(frames)
    all_embeddings = np.vstack(embeddings)
    df = pd.DataFrame(all_embeddings)

    labels = []
    for i, uploaded_annots in enumerate(annotations):
        print(i)
        _, ext = os.path.splitext(uploaded_annots[0].name)
        if ext == '.annot':
            annot, _, _, sr = load_multiple_annotations_io(uploaded_annots)
            annot_labels = generate_label_array(annot, downsample_rate, len(frames[i]))
        elif ext == '.csv':
            annot_df = pd.read_csv(uploaded_annots[0], header=None)
            annot_labels = annot_df[0].to_list()[::downsample_rate]
            assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure that the passed in csv file has no header."
        else:
            raise ValueError(f'Incompatible file for annotations used. Got a file of type "{ext}".')
        assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure you have passed in the correct files."
        print(annot_labels)
        labels.append(annot_labels)
    all_labels = np.hstack(labels)
    print(len(all_labels))
    df['Label'] = all_labels

    all_frames = np.hstack(frames)
    df['Frame'] = all_frames
    sources = [[fname for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)]
    all_sources = np.hstack(sources)
    df['Source'] = all_sources

    if test_fnames:
        t_split = lambda x: x in test_fnames
        test = [[t_split(fname) for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)]
    else:
        test = [[True for _ in range(len(frames[i]))] for i, _ in enumerate(fnames)]
    all_test = np.hstack(test)
    df['Test'] = all_test

    if views:
        view = [[views[i] for _ in range(len(frames[i]))] for i in range(len(fnames))]
    else:
        view = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))]
    all_view = np.hstack(view)
    df['View'] = all_view

    if conditions:
        condition = [[conditions[i] for _ in range(len(frames[i]))] for i in range(len(fnames))]
    else:
        condition = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))]
    all_condition = np.hstack(condition)
    df['Condition'] = all_condition
    return df

def process_dataset_in_mem(embeddings_df: pd.DataFrame,
                           specified_classes=None,
                           classes_to_remove=None,
                           max_class_size=None,
                           animal_state=None,
                           view=None,
                           shuffle_data=False,
                           test_videos=None):
    """
    Processes output generated from embeddings paired with images and behavior labels.

    Parameters:
    -----------
    embeddings_df : pandas.DataFrame
        The original data. This should contain embeddings, a column named
        `'Label'` and a column named `'Images'`.
    specified_classes : None | list[str]
        An optional input. Defines labels which should be kept as is in the `'Label'`
        column and which should be changed to a default `other` label.
    classes_to_remove : None | list[str]
        An optional input. Drops rows from the dataframe which contain a label in the
        list.
    max_class_size : None | int
        An optional input. Determines the maximum number of rows a single label can
        appear in for each unique label in the `'Label'` column.
    animal_state : None | str
        An optional input. Drops rows from the dataframe which do not contain a match
        for `animal_state` in the text field within the `'Images'` column.
    view : None | str
        An optional input. Drops rows from the dataframe which do not contain a match
        for `view` in the text field within the `'Images'` column.
    shuffle_data : bool
        Determines whether the dataframe should have its rows shuffled.
    test_videos : None | list[str]
        An optional input. Determines which rows should be in the `test` dataframe and
        which should be in the `train` dataframe, by keeping or dropping rows whose
        `'Images'` text field matches a `str` in `test_videos`.

    Returns:
    --------
    balanced_train_embeddings : numpy.ndarray
        A processed array whose rows contain the embeddings for each of the images
        at the corresponding index within `balanced_train_images`.
    balanced_train_labels : list[str]
        A list of labels for each of the images at the corresponding index within
        `balanced_train_images`.
    balanced_train_images : list[str]
        A list of paths to images, with each image at an index corresponding to a label
        with the same index in `balanced_train_labels` and the same row index within
        `balanced_train_embeddings`.
    test_embeddings : numpy.ndarray
        A processed array whose rows contain the embeddings for each of the images
        at the corresponding index within `test_images`.
    test_labels : list[str]
        A list of labels for each of the images at the corresponding index within
        `test_images`.
    test_images : list[str]
        A list of paths to images, with each image at an index corresponding to a label
        with the same index in `test_labels` and the same row index within
        `test_embeddings`.
    """
    # Convert embeddings, labels, and images to a DataFrame for easy manipulation
    df = copy.deepcopy(embeddings_df)
    df_keys = [str(x) for x in df.keys()]
    # Filter by fed or fasted
    if 'Condition' in df_keys and animal_state:
        df = df[df['Condition'].str.contains(animal_state, na=False)]

    if 'View' in df_keys and view:
        df = df[df['View'].str.contains(view, na=False)]

    # Extract unique video names excluding the frame number
    #unique_video_names = df['Images'].apply(lambda x: '_'.join(x.split('_')[:-1])).unique()
    #print("\nUnique video names:\n", unique_video_names)

    if classes_to_remove:
        df = df[~df['Label'].str.contains('|'.join(classes_to_remove), na=False)]
    # NOTE: this branch is unreachable; any truthy classes_to_remove is caught above.
    elif classes_to_remove and 'all' in classes_to_remove:
        df = df[df['Label'].str.contains('|'.join(classes_to_remove), na=False)]

    # Further filter to include only specified_classes
    if specified_classes:
        single_match = lambda x: list(set(x.split('||')) & set(specified_classes))[0]
        df['Label'] = df['Label'].apply(lambda x: single_match(x) if not set(x.split('||')).isdisjoint(specified_classes) else 'other')
        specified_classes.append('other')

    # Separate the DataFrame into test and training sets based on test_videos
    if 'Test' in df_keys and test_videos:
        test_df = df[df['Test']]
        train_df = df[~df['Test']]
    elif test_videos:
        test_df = df[df['Images'].str.contains('|'.join(test_videos), na=False)]
        train_df = df[~df['Images'].str.contains('|'.join(test_videos), na=False)]
    else:
        test_df = pd.DataFrame(columns=df.columns)
        train_df = df

    # Print the number of frames in each class before balancing
    label_counts = train_df['Label'].value_counts()
    print("\nNumber of training frames in each class before balancing:")
    print(label_counts)

    if max_class_size:
        balanced_train_df = pd.concat([
            group.sample(n=min(len(group), max_class_size), random_state=1)
            for label, group in train_df.groupby('Label')
        ])
    else:
        balanced_train_df = train_df

    # Shuffle the training DataFrame
    if shuffle_data:
        balanced_train_df = balanced_train_df.sample(frac=1).reset_index(drop=True)

    # Convert training set back to numpy array and list
    if "Images" not in df_keys:
        balanced_train_embeddings = balanced_train_df.drop(columns=['Label', 'Frame', 'Source', 'Test', 'View', 'Condition']).to_numpy()
        balanced_train_labels = balanced_train_df['Label'].tolist()
        balanced_train_images = balanced_train_df['Frame'].tolist()

        # Convert test set back to numpy array and list
        test_embeddings = test_df.drop(columns=['Label', 'Frame', 'Source', 'Test', 'View', 'Condition']).to_numpy()
        test_labels = test_df['Label'].tolist()
        test_images = test_df['Frame'].tolist()
    else:
        # Convert training set back to numpy array and list
        balanced_train_embeddings = balanced_train_df.drop(columns=['Label', 'Images']).to_numpy()
        balanced_train_labels = balanced_train_df['Label'].tolist()
        balanced_train_images = balanced_train_df['Images'].tolist()

        # Convert test set back to numpy array and list
        if 'Test' in test_df:
            test_embeddings = test_df.drop(columns=['Label', 'Images', 'Test']).to_numpy()
        else:
            test_embeddings = test_df.drop(columns=['Label', 'Images']).to_numpy()

        test_labels = test_df['Label'].tolist()
        test_images = test_df['Images'].tolist()

    # Print the number of frames in each class after balancing
    if specified_classes or max_class_size:
        balanced_label_counts = Counter(balanced_train_labels)
        print("\nNumber of training frames in each class after balancing:")
        print(balanced_label_counts)

    test_label_counts = test_df['Label'].value_counts()
    # print("\nNumber of testing frames in each class:")
    print(test_label_counts)

    return balanced_train_embeddings, balanced_train_labels, balanced_train_images, test_embeddings, test_labels, test_images

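# Usage sketch, continuing from the create_embeddings_csv example above:
#   (train_X, train_y, train_frames,
#    test_X, test_y, test_frames) = process_dataset_in_mem(
#       df, specified_classes=['walk'], max_class_size=500, shuffle_data=True)
# Labels other than 'walk' are collapsed into 'other'.
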
def multiclass_merge_and_filter_bouts(multiclass_vector, bout_threshold, proximity_threshold):
    # Get the unique labels in the multiclass vector (excluding zero, assuming zero is the background/no label)
    unique_labels = np.unique(multiclass_vector)
    unique_labels = unique_labels[unique_labels != 0]

    # Initialize a vector to store the merged and filtered multiclass vector
    merged_vector = np.zeros_like(multiclass_vector)

    for label in unique_labels:
        # Create a binary vector for the current label
        binary_vector = (multiclass_vector == label)

        # Find the start and end indices of all sequences of 1's for this label
        starts = np.where(np.diff(np.concatenate(([0], binary_vector))) == 1)[0]
        ends = np.where(np.diff(np.concatenate((binary_vector, [0]))) == -1)[0]

        # Step 1: Merge close short bouts
        i = 0
        while i < len(starts) - 1:
            # Check if the gap between the end of the current bout and the start of the next bout
            # is within the proximity threshold
            if starts[i + 1] - ends[i] <= proximity_threshold:
                # Merge the two bouts by setting all elements between the start of the first
                # and the end of the second bout to 1
                binary_vector[ends[i]:starts[i + 1]] = 1
                # Remove the next bout from consideration
                starts = np.delete(starts, i + 1)
                ends = np.delete(ends, i)
            else:
                i += 1

        # Update the starts and ends after merging
        starts = np.where(np.diff(np.concatenate(([0], binary_vector))) == 1)[0]
        ends = np.where(np.diff(np.concatenate((binary_vector, [0]))) == -1)[0]

        # Step 2: Remove standalone short bouts
        for i in range(len(starts)):
            # Check the length of the bout
            length_of_bout = ends[i] - starts[i] + 1

            # If the length is less than the threshold, set those elements to 0
            if length_of_bout < bout_threshold:
                binary_vector[starts[i]:ends[i] + 1] = 0

        # Combine the binary vector with the merged_vector, ensuring only the current label is set
        merged_vector[binary_vector] = label

    # Return the filtered multiclass vector
    return merged_vector

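# Worked example with bout_threshold=2 and proximity_threshold=2:
#   multiclass_merge_and_filter_bouts(np.array([0, 1, 1, 0, 1, 0, 0, 2]), 2, 2)
#   -> array([0, 1, 1, 1, 1, 0, 0, 0])
# The two nearby label-1 bouts merge across the 1-frame gap; the lone
# single-frame label-2 bout is shorter than bout_threshold and is removed.
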
def get_unique_labels(label_list: list[str]):
    label_set = set()
    for label in label_list:
        individual_labels = label.split('||')
        for individual_label in individual_labels:
            label_set.add(individual_label)
    return list(label_set)

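# Example: get_unique_labels(['walk||sniff', 'rest']) -> ['walk', 'sniff', 'rest']
# (set-backed, so the order of the returned labels is not guaranteed)
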
def get_train_test_split(train_embeds, numerical_labels, test_size=0.05, random_state=42):
    return train_test_split(train_embeds, numerical_labels, test_size=test_size, random_state=random_state)

def train_model(X_train, y_train, random_state=42):
    # Train SVM Classifier
    svm_clf = SVC(kernel='rbf', random_state=random_state, probability=True)
    svm_clf.fit(X_train, y_train)
    return svm_clf

def pickle_model(model):
    pickled = io.BytesIO()
    pickle.dump(model, pickled)
    return pickled

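# Training sketch, continuing from the process_dataset_in_mem example:
#   classes = get_unique_labels(train_y)
#   y_num = [classes.index(lbl) for lbl in train_y]
#   X_tr, X_val, y_tr, y_val = get_train_test_split(train_X, y_num)
#   clf = train_model(X_tr, y_tr)
#   print(classification_report(y_val, clf.predict(X_val)))
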
def get_seq_io_reader(uploaded_file):
    assert uploaded_file.name[-3:] == 'seq', 'Not a seq file'
    with NamedTemporaryFile(suffix="seq", delete=False) as temp:
        temp.write(uploaded_file.getvalue())
    sr = seqIo_reader(temp.name)
    return sr

def seq_to_arr(sr):
    N = sr.header['numFrames']
    images = []
    for f in range(N):
        I, ts = sr.getFrame(f)
        images.append(I)
    return np.array(images)

def get_2d_embedding(embeddings: pd.DataFrame):
    tsne = TSNE(n_jobs=4, n_components=2, random_state=42, perplexity=50)
    embedding_2d = tsne.fit_transform(np.array(embeddings))
    return embedding_2d
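# 2-D projection sketch (t-SNE requires more samples than its perplexity of 50):
#   xy = get_2d_embedding(pd.DataFrame(train_X))   # shape (n_samples, 2)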