ncoria committed
Commit 3f0e895 · verified · 1 Parent(s): da69119

update utils generate_embeddings_io

Files changed (1)
  1. utils/utils.py +636 -636
utils/utils.py CHANGED
The commit rewrites all 636 lines of the file (hence the +636 -636 above), but only one line of content actually changes. The substantive difference is in generate_embeddings_stream_io: the per-video CSV written when save_csv=True is now named with the frame count N instead of interpolating the full frames list into the filename.

@@ -227 +227 @@
-            df.to_csv(f'embeddings_downsample_{downsample_rate}_{frames}_frames.csv', index=False)
+            df.to_csv(f'embeddings_downsample_{downsample_rate}_{N}_frames.csv', index=False)
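For illustration only (not part of the commit), here is the practical effect of that one-line change on the CSV filename, assuming a hypothetical 1000-frame video and the default downsample_rate of 4:

downsample_rate = 4
N = 1000                                    # frame count reported by the reader
frames = list(range(N))[::downsample_rate]  # [0, 4, 8, ..., 996]

# before: the whole frames list was interpolated into the filename
f'embeddings_downsample_{downsample_rate}_{frames}_frames.csv'
# -> 'embeddings_downsample_4_[0, 4, 8, ...]_frames.csv' (every kept frame index spelled out)

# after: only the total frame count appears
f'embeddings_downsample_{downsample_rate}_{N}_frames.csv'
# -> 'embeddings_downsample_4_1000_frames.csv'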
The full utils/utils.py after this commit:

import os
import io
import pickle
import copy
from collections import Counter
from pathlib import Path
from tempfile import NamedTemporaryFile
import regex as re
import numpy as np
import pandas as pd
from sklearn.manifold import TSNE
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import torch
from tqdm import tqdm
from PIL import Image
from transformers import AutoProcessor, AutoModel
import streamlit as st
from .data_loading import load_multiple_annotations, load_multiple_annotations_io
from .data_processing import generate_label_array
from .seqIo import seqIo_reader
from .mp4Io import mp4Io_reader

SLIP_MODEL_ID = "google/siglip-so400m-patch14-384"
CLIP_MODEL_ID = "openai/clip-vit-base-patch32"

def create_annot_fname_dict(annot_fnames: list[str])-> dict:
    fs = re.compile(r'.*(_\d+)$')

    unique_files = set()
    for file in annot_fnames:
        file_name = os.fsdecode(file)
        base_name, _ = os.path.splitext(file_name)
        if fs.match(base_name):
            ind = len(fs.match(base_name).group(1))
            unique_files.add(base_name[:-ind])
        else:
            unique_files.add(base_name)

    annot_fname_dict = {}
    for unique_file in unique_files:
        annot_fname_dict.update({unique_file: [file for file in annot_fnames if unique_file in file]})
    return annot_fname_dict

def create_annot_fname_dict_io(annot_fnames: list[str], annot_files: list)-> dict:
    annot_file_dict = {}
    for file in annot_files:
        annot_file_dict.update({file.name : file})
    fs = re.compile(r'.*(_\d+)$')

    unique_files = set()
    for file in annot_fnames:
        file_name = os.fsdecode(file)
        base_name, _ = os.path.splitext(file_name)
        if fs.match(base_name):
            ind = len(fs.match(base_name).group(1))
            unique_files.add(base_name[:-ind])
        else:
            unique_files.add(base_name)

    annot_fname_dict = {}
    for unique_file in unique_files:
        annot_list = [file for file in annot_fnames if unique_file in file]
        annot_list.sort()
        annot_file_list = [annot_file_dict[annot_file_name] for annot_file_name in annot_list]
        annot_fname_dict.update({unique_file: annot_file_list})
    return annot_fname_dict

def get_io_reader(uploaded_file):
    assert uploaded_file.name[-3:]=='seq', 'Not a seq file'
    with NamedTemporaryFile(suffix="seq", delete=False) as temp:
        temp.write(uploaded_file.getvalue())
        sr = seqIo_reader(temp.name)
    return sr

def load_slip_model(device):
    return AutoModel.from_pretrained(SLIP_MODEL_ID).to(device)

def load_slip_preprocessor():
    return AutoProcessor.from_pretrained(SLIP_MODEL_ID)

def load_clip_model(device):
    return AutoModel.from_pretrained(CLIP_MODEL_ID).to(device)

def load_clip_preprocessor():
    return AutoProcessor.from_pretrained(CLIP_MODEL_ID)

def encode_image(image, device, model, processor):
    with torch.no_grad():
        #convert_models_to_fp32(model)
        inputs = processor(images=image, return_tensors="pt").to(device)
        image_features = model.get_image_features(**inputs)
        return image_features.cpu().numpy().flatten()

def generate_embeddings_stream(fnames : list[str],
                               model = 'SLIP',
                               downsample_rate = 4,
                               save_csv = False)-> tuple[list, list]:
    # set up model and device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    if model == 'SLIP':
        embed_model = load_slip_model(device)
        processor = load_slip_preprocessor()
    elif model == 'CLIP':
        embed_model = load_clip_model(device)
        processor = load_clip_preprocessor()

    all_video_embeddings = []
    all_video_frames = []
    for fname in fnames:
        # read in file
        is_seq = False
        if fname[-3:] == 'seq': is_seq = True

        if is_seq:
            sr = seqIo_reader(fname)
        else:
            sr = mp4Io_reader(fname)
        N = sr.header['numFrames']

        # set up embeddings and frame arrays
        embeddings = []
        frames = list(range(N))[::downsample_rate]
        print(frames)

        # create progress bar
        i = 0
        pbar_text = lambda i: f'Creating embeddings for {fname}. {i}/{len(frames)} frames.'
        pbar = st.progress(0, text=pbar_text(0))

        # convert each frame to embeddings
        for f in tqdm(frames):
            img, _ = sr.getFrame(f)
            img_arr = np.array(img)
            if is_seq:
                img_rgb = Image.fromarray(img_arr, 'L').convert('RGB')
            else:
                img_rgb = Image.fromarray(img_arr).convert('RGB')

            embeddings.append(encode_image(img_rgb, device, embed_model, processor))

            # update progress bar
            i += 1
            pbar.progress(i/len(frames), pbar_text(i))

        # save csv of single file
        if save_csv:
            df = pd.DataFrame(embeddings)
            df['Frame'] = frames

            # save csv
            basename = Path(fname).stem
            df.to_csv(f'{basename}_embeddings_downsample_{downsample_rate}.csv', index=False)

        all_video_embeddings.append(np.array(embeddings))
        all_video_frames.append(frames)
    return all_video_embeddings, all_video_frames

# Note: supersedes the seq-only get_io_reader defined above; this version also handles .mp4 uploads.
def get_io_reader(uploaded_file):
    if uploaded_file.name[-3:]=='seq':
        with NamedTemporaryFile(suffix="seq", delete=False) as temp:
            temp.write(uploaded_file.getvalue())
            sr = seqIo_reader(temp.name)
    else:
        with NamedTemporaryFile(suffix="mp4", delete=False) as temp:
            temp.write(uploaded_file.getvalue())
            sr = mp4Io_reader(temp.name)
    return sr

def generate_embeddings_stream_io(uploaded_files : list,
                                  model = 'SLIP',
                                  downsample_rate = 4,
                                  save_csv = False)-> tuple[list, list]:
    # set up model and device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    if model == 'SLIP':
        embed_model = load_slip_model(device)
        processor = load_slip_preprocessor()
    elif model == 'CLIP':
        embed_model = load_clip_model(device)
        processor = load_clip_preprocessor()

    all_video_embeddings = []
    all_video_frames = []
    for file in uploaded_files:
        is_seq = False
        if file.name[-3:] == 'seq': is_seq = True

        # read in file
        sr = get_io_reader(file)
        N = sr.header['numFrames']

        # set up embeddings and frame arrays
        embeddings = []
        frames = list(range(N))[::downsample_rate]
        print(frames)

        # create progress bar
        i = 0
        pbar_text = lambda i: f'Creating embeddings for {file.name}. {i}/{len(frames)} frames.'
        pbar = st.progress(0, text=pbar_text(0))

        # convert each frame to embeddings
        for f in tqdm(frames):
            img, _ = sr.getFrame(f)
            img_arr = np.array(img)
            if is_seq:
                img_rgb = Image.fromarray(img_arr, 'L').convert('RGB')
            else:
                img_rgb = Image.fromarray(img_arr).convert('RGB')

            embeddings.append(encode_image(img_rgb, device, embed_model, processor))

            # update progress bar
            i += 1
            pbar.progress(i/len(frames), pbar_text(i))

        # save csv of single file
        if save_csv:
            df = pd.DataFrame(embeddings)
            df['Frame'] = frames

            # save csv
            df.to_csv(f'embeddings_downsample_{downsample_rate}_{N}_frames.csv', index=False)

        all_video_embeddings.append(np.array(embeddings))
        all_video_frames.append(frames)
    return all_video_embeddings, all_video_frames

def create_embeddings_csv(out: str,
                          fnames: list[str],
                          embeddings: list[np.ndarray],
                          frames: list[list[int]],
                          annotations: list[list[str]],
                          test_fnames: None | list[str],
                          views: None | list[str],
                          conditions: None | list[str],
                          downsample_rate = 4,
                          filesystem = None):
    """
    Creates a .csv file containing all of the generated embeddings and provided information.

    Parameters:
    -----------
    out : str
        The name of the resulting file.
    fnames : list[str]
        Video sources for each of the embedding arrays.
    embeddings : np.ndarray
        The generated embeddings from the images.
    downsample_rate : int
        The downsample_rate used for generating the embeddings.
    """
    assert len(fnames) == len(embeddings)
    assert len(embeddings) == len(frames)
    all_embeddings = np.vstack(embeddings)
    df = pd.DataFrame(all_embeddings)

    labels = []
    for i, annot_fnames in enumerate(annotations):
        _, ext = os.path.splitext(annot_fnames[0])
        if ext == '.annot':
            annot, _, _, sr = load_multiple_annotations(annot_fnames, filesystem=filesystem)
            annot_labels = generate_label_array(annot, downsample_rate, len(frames[i]))
        elif ext == '.csv':
            if not filesystem:
                annot_df = pd.read_csv(annot_fnames[0], header=None)
            else:
                with filesystem.open(annot_fnames[0], 'r') as csv_file:
                    annot_df = pd.read_csv(csv_file, header=None)
            annot_labels = annot_df[0].to_list()[::downsample_rate]
            assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure that the passed in csv file has no header."
        else:
            raise ValueError(f'Incompatible file for annotations used. Got a file of type "{ext}".')
        assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure you have passed in the correct files."
        print(annot_labels)
        labels.append(annot_labels)
    all_labels = np.hstack(labels)
    print(len(all_labels))
    df['Label'] = all_labels

    all_frames = np.hstack(frames)
    df['Frame'] = all_frames
    sources = [[fname for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)]
    all_sources = np.hstack(sources)
    df['Source'] = all_sources

    if test_fnames:
        t_split = lambda x: True if x in test_fnames else False
        test = [[t_split(fname) for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)]
    else:
        test = [[True for _ in range(len(frames[i]))] for i, _ in enumerate(fnames)]
    all_test = np.hstack(test)
    df['Test'] = all_test

    if views:
        view = [[views[i] for _ in range(len(frames[i]))] for i in range(len(fnames))]
    else:
        view = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))]
    all_view = np.hstack(view)
    df['View'] = all_view

    if conditions:
        condition = [[conditions[i] for _ in range(len(frames[i]))] for i in range(len(fnames))]
    else:
        condition = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))]
    all_condition = np.hstack(condition)
    df['Condition'] = all_condition
    return df

def create_embeddings_csv_io(out: str,
                             fnames: list[str],
                             embeddings: list[np.ndarray],
                             frames: list[list[int]],
                             annotations: list,
                             test_fnames: None | list[str],
                             views: None | list[str],
                             conditions: None | list[str],
                             downsample_rate = 4):
    """
    Creates a .csv file containing all of the generated embeddings and provided information.

    Parameters:
    -----------
    out : str
        The name of the resulting file.
    fnames : list[str]
        Video sources for each of the embedding arrays.
    embeddings : np.ndarray
        The generated embeddings from the images.
    downsample_rate : int
        The downsample_rate used for generating the embeddings.
    """
    assert len(fnames) == len(embeddings)
    assert len(embeddings) == len(frames)
    all_embeddings = np.vstack(embeddings)
    df = pd.DataFrame(all_embeddings)

    labels = []
    for i, uploaded_annots in enumerate(annotations):
        print(i)
        _, ext = os.path.splitext(uploaded_annots[0].name)
        if ext == '.annot':
            annot, _, _, sr = load_multiple_annotations_io(uploaded_annots)
            annot_labels = generate_label_array(annot, downsample_rate, len(frames[i]))
        elif ext == '.csv':
            annot_df = pd.read_csv(uploaded_annots[0], header=None)
            annot_labels = annot_df[0].to_list()[::downsample_rate]
            assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure that the passed in csv file has no header."
        else:
            raise ValueError(f'Incompatible file for annotations used. Got a file of type "{ext}".')
        assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure you have passed in the correct files."
        print(annot_labels)
        labels.append(annot_labels)
    all_labels = np.hstack(labels)
    print(len(all_labels))
    df['Label'] = all_labels

    all_frames = np.hstack(frames)
    df['Frame'] = all_frames
    sources = [[fname for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)]
    all_sources = np.hstack(sources)
    df['Source'] = all_sources

    if test_fnames:
        t_split = lambda x: True if x in test_fnames else False
        test = [[t_split(fname) for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)]
    else:
        test = [[True for _ in range(len(frames[i]))] for i, _ in enumerate(fnames)]
    all_test = np.hstack(test)
    df['Test'] = all_test

    if views:
        view = [[views[i] for _ in range(len(frames[i]))] for i in range(len(fnames))]
    else:
        view = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))]
    all_view = np.hstack(view)
    df['View'] = all_view

    if conditions:
        condition = [[conditions[i] for _ in range(len(frames[i]))] for i in range(len(fnames))]
    else:
        condition = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))]
    all_condition = np.hstack(condition)
    df['Condition'] = all_condition
    return df

def process_dataset_in_mem(embeddings_df: pd.DataFrame,
                           specified_classes=None,
                           classes_to_remove=None,
                           max_class_size=None,
                           animal_state=None,
                           view=None,
                           shuffle_data=False,
                           test_videos=None):
    """
    Processes output generated from embeddings paired with images and behavior labels.

    Parameters:
    -----------
    embeddings_df : pandas.DataFrame
        The dataframe containing the original data. This should contain embeddings,
        a column named `'Label'` and a column named `'Images'`.
    specified_classes : None | list[str]
        An optional input. Defines labels which should be kept as is in the `'Label'`
        column and which should be changed to a default `other` label.
    classes_to_remove : None | list[str]
        An optional input. Drops rows from the dataframe which contain a label in the
        list.
    max_class_size : None | int
        An optional input. Determines the maximum number of rows a single label can
        appear in for each unique label in the `'Label'` column.
    animal_state : None | str
        An optional input. Drops rows from the dataframe which do not contain a match
        for `animal_state` in the text field within the `'Images'` column.
    view : None | str
        An optional input. Drops rows from the dataframe which do not contain a match
        for `view` in the text field within the `'Images'` column.
    shuffle_data : bool
        Determines whether the dataframe should have its rows shuffled.
    test_videos : None | list[str]
        An optional input. Determines what rows should be in the `test` dataframe, and
        which should be in the `train` dataframe. It drops rows from the respective
        dataframe by keeping or dropping rows which do not contain a match for a `str`
        in `test_videos` in the text field within the `'Images'` column, respectively.

    Returns:
    --------
    balanced_train_embeddings : pandas.DataFrame
        A processed dataframe whose rows contain the embeddings for each of the images
        at the corresponding index within `balanced_train_images`.
    balanced_train_labels : list[str]
        A list of labels for each of the images at the corresponding index within
        `balanced_train_images`.
    balanced_train_images : list[str]
        A list of paths to images with each image at an index corresponding to a label
        with the same index in `balanced_train_labels` and the same row index within
        `balanced_train_embeddings`.
    test_embeddings : pandas.DataFrame
        A processed dataframe whose rows contain the embeddings for each of the images
        at the corresponding index within `test_images`.
    test_labels : list[str]
        A list of labels for each of the images at the corresponding index within
        `test_images`.
    test_images : list[str]
        A list of paths to images with each image at an index corresponding to a label
        with the same index in `test_labels` and the same row index within
        `test_embeddings`.
    """
    # Convert embeddings, labels, and images to a DataFrame for easy manipulation
    df = copy.deepcopy(embeddings_df)
    df_keys = [str(x) for x in df.keys()]
    # Filter by fed or fasted
    if 'Condition' in df_keys and animal_state:
        df = df[df['Condition'].str.contains(animal_state, na=False)]

    if 'View' in df_keys and view:
        df = df[df['View'].str.contains(view, na=False)]

    # Extract unique video names excluding the frame number
    #unique_video_names = df['Images'].apply(lambda x: '_'.join(x.split('_')[:-1])).unique()
    #print("\nUnique video names:\n", unique_video_names)

    if classes_to_remove:
        df = df[~df['Label'].str.contains('|'.join(classes_to_remove), na=False)]
    elif classes_to_remove and 'all' in classes_to_remove:
        df = df[df['Label'].str.contains('|'.join(classes_to_remove), na=False)]

    # Further filter to include only specified_classes
    if specified_classes:
        single_match = lambda x: list(set(x.split('||')) & set(specified_classes))[0]
        df['Label'] = df['Label'].apply(lambda x: single_match(x) if not set(x.split('||')).isdisjoint(specified_classes) else 'other')
        specified_classes.append('other')

    # Separate the DataFrame into test and training sets based on test_videos
    if 'Test' in df_keys and test_videos:
        test_df = df[df['Test']]
        train_df = df[~df['Test']]
    elif test_videos:
        test_df = df[df['Images'].str.contains('|'.join(test_videos), na=False)]
        train_df = df[~df['Images'].str.contains('|'.join(test_videos), na=False)]
    else:
        test_df = pd.DataFrame(columns=df.columns)
        train_df = df

    # Print the number of frames in each class before balancing
    label_counts = train_df['Label'].value_counts()
    print("\nNumber of training frames in each class before balancing:")
    print(label_counts)

    if max_class_size:
        balanced_train_df = pd.concat([
            group.sample(n=min(len(group), max_class_size), random_state=1)
            for label, group in train_df.groupby('Label')
        ])
    else:
        balanced_train_df = train_df

    # Shuffle the training DataFrame
    if shuffle_data:
        balanced_train_df = balanced_train_df.sample(frac=1).reset_index(drop=True)

    # Convert training set back to numpy array and list
    if not "Images" in df_keys:
        balanced_train_embeddings = balanced_train_df.drop(columns=['Label', 'Frame', 'Source', 'Test','View','Condition']).to_numpy()
        balanced_train_labels = balanced_train_df['Label'].tolist()
        balanced_train_images = balanced_train_df['Frame'].tolist()

        # Convert test set back to numpy array and list
        test_embeddings = test_df.drop(columns=['Label', 'Frame', 'Source', 'Test','View','Condition']).to_numpy()
        test_labels = test_df['Label'].tolist()
        test_images = test_df['Frame'].tolist()
    else:
        # Convert training set back to numpy array and list
        balanced_train_embeddings = balanced_train_df.drop(columns=['Label', 'Images']).to_numpy()
        balanced_train_labels = balanced_train_df['Label'].tolist()
        balanced_train_images = balanced_train_df['Images'].tolist()

        # Convert test set back to numpy array and list
        if 'Test' in test_df:
            test_embeddings = test_df.drop(columns=['Label', 'Images', 'Test']).to_numpy()
        else:
            test_embeddings = test_df.drop(columns=['Label', 'Images']).to_numpy()

        test_labels = test_df['Label'].tolist()
        test_images = test_df['Images'].tolist()

    # Print the number of frames in each class after balancing
    if specified_classes or max_class_size:
        balanced_label_counts = Counter(balanced_train_labels)
        print("\nNumber of training frames in each class after balancing:")
        print(balanced_label_counts)

    test_label_counts = test_df['Label'].value_counts()
    # print("\nNumber of testing frames in each class:")
    print(test_label_counts)

    return balanced_train_embeddings, balanced_train_labels, balanced_train_images, test_embeddings, test_labels, test_images

def multiclass_merge_and_filter_bouts(multiclass_vector, bout_threshold, proximity_threshold):
    # Get the unique labels in the multiclass vector (excluding zero, assuming zero is the background/no label)
    unique_labels = np.unique(multiclass_vector)
    unique_labels = unique_labels[unique_labels != 0]

    # Initialize a vector to store the merged and filtered multiclass vector
    merged_vector = np.zeros_like(multiclass_vector)

    for label in unique_labels:
        # Create a binary vector for the current label
        binary_vector = (multiclass_vector == label)

        # Find the start and end indices of all sequences of 1's for this label
        starts = np.where(np.diff(np.concatenate(([0], binary_vector))) == 1)[0]
        ends = np.where(np.diff(np.concatenate((binary_vector, [0]))) == -1)[0]

        # Step 1: Merge close short bouts
        i = 0
        while i < len(starts) - 1:
            # Check if the gap between the end of the current bout and the start of the next bout
            # is within the proximity threshold
            if starts[i + 1] - ends[i] <= proximity_threshold:
                # Merge the two bouts by setting all elements between the start of the first
                # and the end of the second bout to 1
                binary_vector[ends[i]:starts[i + 1]] = 1
                # Remove the next bout from consideration
                starts = np.delete(starts, i + 1)
                ends = np.delete(ends, i)
            else:
                i += 1

        # Update the starts and ends after merging
        starts = np.where(np.diff(np.concatenate(([0], binary_vector))) == 1)[0]
        ends = np.where(np.diff(np.concatenate((binary_vector, [0]))) == -1)[0]

        # Step 2: Remove standalone short bouts
        for i in range(len(starts)):
            # Check the length of the bout
            length_of_bout = ends[i] - starts[i] + 1

            # If the length is less than the threshold, set those elements to 0
            if length_of_bout < bout_threshold:
                binary_vector[starts[i]:ends[i] + 1] = 0

        # Combine the binary vector with the merged_vector, ensuring only the current label is set
        merged_vector[binary_vector] = label

    # Return the filtered multiclass vector
    return merged_vector

def get_unique_labels(label_list: list[str]):
    label_set = set()
    for label in label_list:
        individual_labels = label.split('||')
        for individual_label in individual_labels:
            label_set.add(individual_label)
    return list(label_set)

def get_train_test_split(train_embeds, numerical_labels, test_size=0.05, random_state=42):
    return train_test_split(train_embeds, numerical_labels, test_size=test_size, random_state=random_state)

def train_model(X_train, y_train, random_state=42):
    # Train SVM Classifier
    svm_clf = SVC(kernel='rbf', random_state=random_state, probability=True)
    svm_clf.fit(X_train, y_train)
    return svm_clf

def pickle_model(model):
    pickled = io.BytesIO()
    pickle.dump(model, pickled)
    return pickled

def get_seq_io_reader(uploaded_file):
    assert uploaded_file.name[-3:]=='seq', 'Not a seq file'
    with NamedTemporaryFile(suffix="seq", delete=False) as temp:
        temp.write(uploaded_file.getvalue())
        sr = seqIo_reader(temp.name)
    return sr

def seq_to_arr(sr):
    N = sr.header['numFrames']
    images = []
    for f in range(N):
        I, ts = sr.getFrame(f)
        images.append(I)
    return np.array(images)

def get_2d_embedding(embeddings: pd.DataFrame):
    tsne = TSNE(n_jobs=4, n_components=2, random_state=42, perplexity=50)
    embedding_2d = tsne.fit_transform(np.array(embeddings))
    return embedding_2d
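For orientation, a minimal sketch of how the updated generate_embeddings_stream_io might be wired into the rest of this module from a Streamlit app. The uploader labels, the single-video/annotation pairing, and the import path utils.utils are assumptions for illustration, not part of the commit.

import streamlit as st
from utils.utils import (generate_embeddings_stream_io, create_embeddings_csv_io,
                         process_dataset_in_mem, train_model)

videos = st.file_uploader("Videos (.seq or .mp4)", accept_multiple_files=True)
annots = st.file_uploader("Annotations (.annot or .csv)", accept_multiple_files=True)

if videos and annots:
    # With save_csv=True, each video also writes
    # embeddings_downsample_<rate>_<N>_frames.csv, where N is that video's frame count.
    embeds, frames = generate_embeddings_stream_io(videos, model='SLIP',
                                                   downsample_rate=4, save_csv=True)
    df = create_embeddings_csv_io(out='embeddings',
                                  fnames=[v.name for v in videos],
                                  embeddings=embeds,
                                  frames=frames,
                                  annotations=[annots],  # one list of annotation files per video (single video assumed)
                                  test_fnames=None, views=None, conditions=None,
                                  downsample_rate=4)
    # Split, balance, and fit the SVM on the resulting embedding table.
    X_train, y_train, _, X_test, y_test, _ = process_dataset_in_mem(df, shuffle_data=True)
    clf = train_model(X_train, y_train)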