File size: 25,987 Bytes
3f0e895
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1621356
 
 
 
 
 
 
3f0e895
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
import os
import io
import pickle
import copy
from collections import Counter
from pathlib import Path
from tempfile import NamedTemporaryFile
import regex as re
import numpy as np
import pandas as pd
from sklearn.manifold import TSNE
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import torch
from tqdm import tqdm
from PIL import Image
from transformers import AutoProcessor, AutoModel
import streamlit as st
from .data_loading import load_multiple_annotations, load_multiple_annotations_io
from .data_processing import generate_label_array
from .seqIo import seqIo_reader
from .mp4Io import mp4Io_reader

# Hugging Face model identifiers for the two supported embedding backbones.
SLIP_MODEL_ID = "google/siglip-so400m-patch14-384"  # SigLIP (default backbone)
CLIP_MODEL_ID = "openai/clip-vit-base-patch32"  # OpenAI CLIP alternative

def create_annot_fname_dict(annot_fnames: list[str]) -> dict:
    """Group annotation file names by their shared base name.

    File names that differ only by a trailing ``_<number>`` suffix (before
    the extension) are treated as parts of the same logical annotation and
    grouped under the suffix-stripped base name.

    Parameters
    ----------
    annot_fnames : list[str]
        Annotation file names (with extensions).

    Returns
    -------
    dict
        Maps each base name to the list of file names belonging to it, in
        the order they appear in ``annot_fnames``.
    """
    # Matches base names ending in "_<digits>", capturing the numeric suffix.
    suffix_pattern = re.compile(r'.*(_\d+)$')

    annot_fname_dict: dict = {}
    for file in annot_fnames:
        file_name = os.fsdecode(file)
        base_name, _ = os.path.splitext(file_name)
        # Match once and reuse (the original matched the same name twice).
        match = suffix_pattern.match(base_name)
        if match:
            # Strip the trailing "_<number>" to recover the shared base.
            base_name = base_name[:-len(match.group(1))]
        # Group by the computed base rather than substring containment,
        # which mis-grouped files whose base was a prefix of another base
        # (e.g. "vid" also collected "vid_extra.annot").
        annot_fname_dict.setdefault(base_name, []).append(file)
    return annot_fname_dict

def create_annot_fname_dict_io(annot_fnames: list[str], annot_files: list) -> dict:
    """Group uploaded annotation file objects by their shared base name.

    File names that differ only by a trailing ``_<number>`` suffix (before
    the extension) are treated as parts of the same logical annotation.

    Parameters
    ----------
    annot_fnames : list[str]
        Names of the uploaded annotation files (with extensions).
    annot_files : list
        Uploaded file objects; each must expose a ``.name`` attribute equal
        to one of ``annot_fnames``.

    Returns
    -------
    dict
        Maps each suffix-stripped base name to the list of file objects
        belonging to it, sorted by file name.
    """
    # Look up file objects by their reported name.
    annot_file_dict = {file.name: file for file in annot_files}
    # Matches base names ending in "_<digits>", capturing the numeric suffix.
    suffix_pattern = re.compile(r'.*(_\d+)$')

    # Group names by their computed base instead of substring containment,
    # which mis-grouped files whose base was a prefix of another base.
    grouped_names: dict = {}
    for file in annot_fnames:
        file_name = os.fsdecode(file)
        base_name, _ = os.path.splitext(file_name)
        # Match once and reuse (the original matched the same name twice).
        match = suffix_pattern.match(base_name)
        if match:
            base_name = base_name[:-len(match.group(1))]
        grouped_names.setdefault(base_name, []).append(file_name)

    annot_fname_dict = {}
    for base_name, names in grouped_names.items():
        names.sort()  # deterministic part ordering, as before
        annot_fname_dict[base_name] = [annot_file_dict[name] for name in names]
    return annot_fname_dict

def get_io_reader(uploaded_file):
    # NOTE(review): this definition is shadowed by a later `get_io_reader`
    # in this module (which also handles mp4); only the later one survives
    # at import time. Consider removing this duplicate.
    """Persist an uploaded .seq file to a temp file and open a seq reader on it."""
    assert uploaded_file.name[-3:]=='seq', 'Not a seq file'
    # delete=False: the reader keeps reading from the temp file after this
    # block exits, so it must not be removed here.
    # NOTE(review): suffix="seq" has no dot, so temp names end in "...seq"
    # without an extension separator — confirm seqIo_reader does not care.
    with NamedTemporaryFile(suffix="seq", delete=False) as temp:
        temp.write(uploaded_file.getvalue())
        sr = seqIo_reader(temp.name)
    return sr

def load_slip_model(device):
    """Load the SigLIP model (SLIP_MODEL_ID) and move it to `device`."""
    return AutoModel.from_pretrained(SLIP_MODEL_ID).to(device)

def load_slip_preprocessor():
    """Load the processor (image preprocessing) matching SLIP_MODEL_ID."""
    return AutoProcessor.from_pretrained(SLIP_MODEL_ID)

def load_clip_model(device):
    """Load the CLIP model (CLIP_MODEL_ID) and move it to `device`."""
    return AutoModel.from_pretrained(CLIP_MODEL_ID).to(device)

def load_clip_preprocessor():
    """Load the processor (image preprocessing) matching CLIP_MODEL_ID."""
    return AutoProcessor.from_pretrained(CLIP_MODEL_ID)

def encode_image(image, device, model, processor):
    """Return a flat 1-D numpy embedding for a single image.

    The image is preprocessed with `processor`, moved to `device`, and
    encoded with `model.get_image_features` under `torch.no_grad()`.
    """
    with torch.no_grad():
        batch = processor(images=image, return_tensors="pt").to(device)
        features = model.get_image_features(**batch)
    return features.cpu().numpy().flatten()

def generate_embeddings_stream(fnames : list[str],
                        model = 'SLIP',
                        downsample_rate = 4,
                        save_csv = False)-> tuple[list, list]:
    """Generate frame embeddings for each video file in `fnames`.

    Parameters
    ----------
    fnames : list[str]
        Paths to .seq or .mp4 video files.
    model : str
        'SLIP' (SigLIP) or 'CLIP'; selects the embedding backbone.
    downsample_rate : int
        Keep every `downsample_rate`-th frame.
    save_csv : bool
        If True, also write `<basename>_embeddings_downsample_<rate>.csv`
        per video.

    Returns
    -------
    tuple[list, list]
        (per-video embedding arrays, per-video lists of frame indices).
        Note: the original annotation claimed three lists; two are returned.
    """
    # set up model and device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    if model == 'SLIP':
        embed_model = load_slip_model(device)
        processor = load_slip_preprocessor()
    elif model == 'CLIP':
        embed_model = load_clip_model(device)
        processor = load_clip_preprocessor()

    all_video_embeddings = []
    all_video_frames = []
    for fname in fnames:
        # pick the reader by extension
        is_seq = fname[-3:] == 'seq'
        sr = seqIo_reader(fname) if is_seq else mp4Io_reader(fname)
        N = sr.header['numFrames']

        # Frame indices to embed; range with a step avoids first
        # materializing the full 0..N-1 list and then slicing it.
        frames = list(range(0, N, downsample_rate))

        # progress bar shown in the Streamlit UI
        def pbar_text(i):
            return f'Creating embeddings for {fname}. {i}/{len(frames)} frames.'
        pbar = st.progress(0, text=pbar_text(0))

        # convert each frame to an embedding
        embeddings = []
        for i, f in enumerate(tqdm(frames), start=1):
            img, _ = sr.getFrame(f)
            img_arr = np.array(img)
            # seq frames are single-channel; promote to RGB either way
            if is_seq:
                img_rgb = Image.fromarray(img_arr, 'L').convert('RGB')
            else:
                img_rgb = Image.fromarray(img_arr).convert('RGB')

            embeddings.append(encode_image(img_rgb, device, embed_model, processor))

            # update progress bar
            pbar.progress(i/len(frames), pbar_text(i))

        # save csv of single file
        if save_csv:
            df = pd.DataFrame(embeddings)
            df['Frame'] = frames

            # save csv
            basename = Path(fname).stem
            df.to_csv(f'{basename}_embeddings_downsample_{downsample_rate}.csv', index=False)

        all_video_embeddings.append(np.array(embeddings))
        all_video_frames.append(frames)
    return all_video_embeddings, all_video_frames

def get_io_reader(uploaded_file):
    """Persist an uploaded video to a temp file and open the matching reader.

    A seq reader is used for files whose name ends in "seq", otherwise an
    mp4 reader. The temp file is kept (delete=False) because the reader
    continues to read from it after this function returns.
    """
    if uploaded_file.name[-3:] == 'seq':
        suffix, make_reader = "seq", seqIo_reader
    else:
        suffix, make_reader = "mp4", mp4Io_reader
    with NamedTemporaryFile(suffix=suffix, delete=False) as temp:
        temp.write(uploaded_file.getvalue())
        reader = make_reader(temp.name)
    return reader

def generate_embeddings_stream_io(uploaded_files : list,
                                model = 'SLIP',
                                downsample_rate = 4,
                                save_csv = False)-> tuple[list, list]:
    """Generate frame embeddings for each uploaded video file.

    Parameters
    ----------
    uploaded_files : list
        Uploaded file objects (.seq or .mp4) exposing `.name`/`.getvalue()`.
    model : str
        'SLIP' (SigLIP) or 'CLIP'; selects the embedding backbone.
    downsample_rate : int
        Keep every `downsample_rate`-th frame.
    save_csv : bool
        If True, also write a csv of embeddings per video.

    Returns
    -------
    tuple[list, list]
        (per-video embedding arrays, per-video lists of frame indices).
        Note: the original annotation claimed three lists; two are returned.
    """
    # set up model and device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    with st.spinner('Loading multimodal model...'):
        if model == 'SLIP':
            embed_model = load_slip_model(device)
            processor = load_slip_preprocessor()
        elif model == 'CLIP':
            embed_model = load_clip_model(device)
            processor = load_clip_preprocessor()

    all_video_embeddings = []
    all_video_frames = []
    for file in uploaded_files:
        is_seq = file.name[-3:] == 'seq'

        # read in file via a temp-file-backed reader
        sr = get_io_reader(file)
        N = sr.header['numFrames']

        # Frame indices to embed; range with a step avoids first
        # materializing the full 0..N-1 list and then slicing it.
        frames = list(range(0, N, downsample_rate))

        # progress bar shown in the Streamlit UI
        def pbar_text(i):
            return f'Creating embeddings for {file.name}. {i}/{len(frames)} frames.'
        pbar = st.progress(0, text=pbar_text(0))

        # convert each frame to an embedding
        embeddings = []
        for i, f in enumerate(tqdm(frames), start=1):
            img, _ = sr.getFrame(f)
            img_arr = np.array(img)
            # seq frames are single-channel; promote to RGB either way
            if is_seq:
                img_rgb = Image.fromarray(img_arr, 'L').convert('RGB')
            else:
                img_rgb = Image.fromarray(img_arr).convert('RGB')

            embeddings.append(encode_image(img_rgb, device, embed_model, processor))

            # update progress bar
            pbar.progress(i/len(frames), pbar_text(i))

        # save csv of single file
        if save_csv:
            df = pd.DataFrame(embeddings)
            df['Frame'] = frames

            # save csv
            df.to_csv(f'embeddings_downsample_{downsample_rate}_{N}_frames.csv', index=False)

        all_video_embeddings.append(np.array(embeddings))
        all_video_frames.append(frames)
    return all_video_embeddings, all_video_frames

def create_embeddings_csv(out: str,
                          fnames: list[str],
                          embeddings: list[np.ndarray],
                          frames: list[list[int]],
                          annotations: list[list[str]],
                          test_fnames: None | list[str],
                          views: None | list[str],
                          conditions: None | list[str],
                          downsample_rate = 4,
                          filesystem = None):
    """
    Builds a DataFrame pairing the generated embeddings with labels and metadata.

    NOTE(review): despite the name and original docstring, this function does
    not write a .csv file — it returns the assembled DataFrame, and `out` is
    currently unused. Confirm whether the caller is expected to save it.

    Parameters:
    -----------
    out : str
        The intended name of the resulting file (currently unused).
    fnames : list[str]
        Video sources for each of the embedding arrays.
    embeddings : list[np.ndarray]
        The generated embeddings, one array per video in `fnames`.
    frames : list[list[int]]
        Per-video frame indices matching each embedding row.
    annotations : list[list[str]]
        Per-video lists of annotation file names (.annot, or headerless .csv
        with one label per frame in column 0).
    test_fnames : None | list[str]
        Videos whose rows are marked Test=True; if None, all rows are True.
    views : None | list[str]
        Optional per-video view labels (None fills the column otherwise).
    conditions : None | list[str]
        Optional per-video condition labels (None fills the column otherwise).
    downsample_rate : int
        The downsample_rate used for generating the embeddings.
    filesystem
        Optional filesystem object with an `open()` method used to read
        annotation files (presumably fsspec-like — TODO confirm).

    Returns:
    --------
    pandas.DataFrame
        Embedding columns plus 'Label', 'Frame', 'Source', 'Test', 'View'
        and 'Condition' columns; rows stacked across all videos.
    """
    assert len(fnames) == len(embeddings)
    assert len(embeddings) == len(frames)
    # Stack all per-video embeddings into one row-aligned matrix.
    all_embeddings = np.vstack(embeddings)
    df = pd.DataFrame(all_embeddings)
    
    labels = []
    for i, annot_fnames in enumerate(annotations):
        # The first file's extension decides how this video's labels load.
        _, ext = os.path.splitext(annot_fnames[0])
        if ext == '.annot':
            annot, _, _, sr = load_multiple_annotations(annot_fnames, filesystem=filesystem)
            annot_labels = generate_label_array(annot, downsample_rate, len(frames[i]))
        elif ext == '.csv':
            if not filesystem: 
                annot_df = pd.read_csv(annot_fnames[0], header=None)
            else:
                with filesystem.open(annot_fnames[0], 'r') as csv_file:
                    annot_df = pd.read_csv(csv_file, header=None)
            # Column 0 holds one label per frame; downsample to match `frames`.
            annot_labels = annot_df[0].to_list()[::downsample_rate]
            assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure that the passed in csv file has no header."
        else:
            raise ValueError(f'Incompatible file for annotations used. Got a file of type "{ext}".')
        assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure you have passed in the correct files."
        print(annot_labels)
        labels.append(annot_labels)
    all_labels = np.hstack(labels)
    print(len(all_labels))
    df['Label'] = all_labels
    
    all_frames = np.hstack(frames)
    df['Frame'] = all_frames
    # Repeat each video's name once per kept frame so every row has a source.
    sources = [[fname for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)]
    all_sources = np.hstack(sources)
    df['Source'] = all_sources

    if test_fnames:
        t_split = lambda x: True if x in test_fnames else False
        test = [[t_split(fname) for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)]
    else:
        # No explicit split requested: mark every row as test.
        test = [[True for _ in range(len(frames[i]))] for i, _ in enumerate(fnames)]
    all_test = np.hstack(test)
    df['Test'] = all_test

    if views:
        view = [[views[i] for _ in range(len(frames[i]))] for i in range(len(fnames))]
    else:
        view = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))]
    all_view = np.hstack(view)
    df['View'] = all_view
    
    if conditions:
        condition = [[conditions[i] for _ in range(len(frames[i]))] for i in range(len(fnames))]
    else:
        condition = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))]
    all_condition = np.hstack(condition)
    df['Condition'] = all_condition
    return df

def create_embeddings_csv_io(out: str,
                          fnames: list[str],
                          embeddings: list[np.ndarray],
                          frames: list[list[int]],
                          annotations: list,
                          test_fnames: None | list[str],
                          views: None | list[str],
                          conditions: None | list[str],
                          downsample_rate = 4):
    """
    Builds a DataFrame pairing the generated embeddings with labels and
    metadata, reading annotations from uploaded file objects.

    NOTE(review): despite the name and original docstring, this function does
    not write a .csv file — it returns the assembled DataFrame, and `out` is
    currently unused. Confirm whether the caller is expected to save it.

    Parameters:
    -----------
    out : str
        The intended name of the resulting file (currently unused).
    fnames : list[str]
        Video sources for each of the embedding arrays.
    embeddings : list[np.ndarray]
        The generated embeddings, one array per video in `fnames`.
    frames : list[list[int]]
        Per-video frame indices matching each embedding row.
    annotations : list
        Per-video lists of uploaded annotation file objects (.annot, or
        headerless .csv with one label per frame in column 0).
    test_fnames : None | list[str]
        Videos whose rows are marked Test=True; if None, all rows are True.
    views : None | list[str]
        Optional per-video view labels (None fills the column otherwise).
    conditions : None | list[str]
        Optional per-video condition labels (None fills the column otherwise).
    downsample_rate : int
        The downsample_rate used for generating the embeddings.

    Returns:
    --------
    pandas.DataFrame
        Embedding columns plus 'Label', 'Frame', 'Source', 'Test', 'View'
        and 'Condition' columns; rows stacked across all videos.
    """
    assert len(fnames) == len(embeddings)
    assert len(embeddings) == len(frames)
    # Stack all per-video embeddings into one row-aligned matrix.
    all_embeddings = np.vstack(embeddings)
    df = pd.DataFrame(all_embeddings)
    
    labels = []
    for i, uploaded_annots in enumerate(annotations):
        print(i)
        # The first file's extension decides how this video's labels load.
        _, ext = os.path.splitext(uploaded_annots[0].name)
        if ext == '.annot':
            annot, _, _, sr = load_multiple_annotations_io(uploaded_annots)
            annot_labels = generate_label_array(annot, downsample_rate, len(frames[i]))
        elif ext == '.csv':
            annot_df = pd.read_csv(uploaded_annots[0], header=None)
            # Column 0 holds one label per frame; downsample to match `frames`.
            annot_labels = annot_df[0].to_list()[::downsample_rate]
            assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure that the passed in csv file has no header."
        else:
            raise ValueError(f'Incompatible file for annotations used. Got a file of type "{ext}".')
        assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure you have passed in the correct files."
        print(annot_labels)
        labels.append(annot_labels)
    all_labels = np.hstack(labels)
    print(len(all_labels))
    df['Label'] = all_labels
    
    all_frames = np.hstack(frames)
    df['Frame'] = all_frames
    # Repeat each video's name once per kept frame so every row has a source.
    sources = [[fname for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)]
    all_sources = np.hstack(sources)
    df['Source'] = all_sources

    if test_fnames:
        t_split = lambda x: True if x in test_fnames else False
        test = [[t_split(fname) for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)]
    else:
        # No explicit split requested: mark every row as test.
        test = [[True for _ in range(len(frames[i]))] for i, _ in enumerate(fnames)]
    all_test = np.hstack(test)
    df['Test'] = all_test

    if views:
        view = [[views[i] for _ in range(len(frames[i]))] for i in range(len(fnames))]
    else:
        view = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))]
    all_view = np.hstack(view)
    df['View'] = all_view
    
    if conditions:
        condition = [[conditions[i] for _ in range(len(frames[i]))] for i in range(len(fnames))]
    else:
        condition = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))]
    all_condition = np.hstack(condition)
    df['Condition'] = all_condition
    return df

def process_dataset_in_mem(embeddings_df: pd.DataFrame,
                    specified_classes=None,
                    classes_to_remove=None,
                    max_class_size=None,
                    animal_state=None,
                    view=None,
                    shuffle_data=False,
                    test_videos=None):
    """
    Processes output generated from embeddings paired with images and behavior labels.

    Parameters:
    -----------
    embeddings_df : pandas.DataFrame
        DataFrame containing the original data. This should contain embeddings,
        a column named `'Label'`, and either an `'Images'` column or the
        `'Frame'`/`'Source'`/`'Test'`/`'View'`/`'Condition'` columns.
    specified_classes : None | list[str]
        An optional input. Defines labels which should be kept as is in the `'Label'`
        column and which should be changed to a default `other` label. The passed-in
        list is not modified.
    classes_to_remove : None | list[str]
        An optional input. Drops rows from the dataframe which contain a label in the
        list. If the list contains `'all'`, rows matching the listed labels are kept
        instead.
    max_class_size : None | int
        An optional input. Determines the maximum amount of rows a single label can
        appear in for each unique label in the `'Label'` column.
    animal_state : None | str
        An optional input. Drops rows from the dataframe which do not contain a match
        for `animal_state` in the `'Condition'` column (when present).
    view : None | str
        An optional input. Drops rows from the dataframe which do not contain a match
        for `view` in the `'View'` column (when present).
    shuffle_data : bool
        Determines whether the training dataframe should have its rows shuffled.
    test_videos : None | list[str]
        An optional input. Determines what rows should be in the `test` dataframe
        (via the boolean `'Test'` column when present, otherwise by matching
        `'Images'` against the listed video names).

    Returns:
    --------
    balanced_train_embeddings : numpy.ndarray
        Embeddings for each training row, aligned with the two lists below.
    balanced_train_labels : list[str]
        A label per training row.
    balanced_train_images : list
        An image path (or frame index when no `'Images'` column exists) per
        training row.
    test_embeddings : numpy.ndarray
        Embeddings for each test row, aligned with the two lists below.
    test_labels : list[str]
        A label per test row.
    test_images : list
        An image path (or frame index) per test row.
    """
    # Work on a deep copy so the caller's DataFrame is never mutated.
    df = copy.deepcopy(embeddings_df)
    df_keys = [str(x) for x in df.keys()]
    # Filter by animal state (e.g. fed or fasted)
    if 'Condition' in df_keys and animal_state:
        df = df[df['Condition'].str.contains(animal_state, na=False)]

    if 'View' in df_keys and view:
        df = df[df['View'].str.contains(view, na=False)]

    if classes_to_remove:
        # BUGFIX: the 'all' special case previously lived in an
        # `elif classes_to_remove and 'all' in ...` after
        # `if classes_to_remove:` and was therefore unreachable;
        # check it first so the keep-matching branch can run.
        if 'all' in classes_to_remove:
            df = df[df['Label'].str.contains('|'.join(classes_to_remove), na=False)]
        else:
            df = df[~df['Label'].str.contains('|'.join(classes_to_remove), na=False)]

    # Further filter to include only specified_classes
    if specified_classes:
        # BUGFIX: copy before extending — the original appended 'other' to
        # the caller's list, mutating an argument as a side effect.
        specified_classes = list(specified_classes)
        # Pick one matching class for multi-labels like "a||b"; anything
        # with no match collapses to 'other'.
        single_match = lambda x: list(set(x.split('||')) & set(specified_classes))[0]
        df['Label'] = df['Label'].apply(lambda x: single_match(x) if not set(x.split('||')).isdisjoint(specified_classes) else 'other')
        specified_classes.append('other')

    # Separate the DataFrame into test and training sets based on test_videos
    if 'Test' in df_keys and test_videos:
        test_df = df[df['Test']]
        train_df = df[~df['Test']]
    elif test_videos:
        test_df = df[df['Images'].str.contains('|'.join(test_videos), na=False)]
        train_df = df[~df['Images'].str.contains('|'.join(test_videos), na=False)]
    else:
        # No split requested: everything trains, test set stays empty.
        test_df = pd.DataFrame(columns=df.columns)
        train_df = df

    # Print the number of frames in each class before balancing
    label_counts = train_df['Label'].value_counts()
    print("\nNumber of training frames in each class before balancing:")
    print(label_counts)

    if max_class_size:
        # Cap each class at max_class_size rows (deterministic sample).
        balanced_train_df = pd.concat([
            group.sample(n=min(len(group), max_class_size), random_state=1)
            for label, group in train_df.groupby('Label')
        ])
    else:
        balanced_train_df = train_df

    # Shuffle the training DataFrame
    if shuffle_data:
        balanced_train_df = balanced_train_df.sample(frac=1).reset_index(drop=True)

    # Convert training set back to numpy array and list
    if not "Images" in df_keys:
        balanced_train_embeddings = balanced_train_df.drop(columns=['Label', 'Frame', 'Source', 'Test','View','Condition']).to_numpy()
        balanced_train_labels = balanced_train_df['Label'].tolist()
        balanced_train_images = balanced_train_df['Frame'].tolist()

        # Convert test set back to numpy array and list
        test_embeddings = test_df.drop(columns=['Label', 'Frame', 'Source', 'Test','View','Condition']).to_numpy()
        test_labels = test_df['Label'].tolist()
        test_images = test_df['Frame'].tolist()
    else:
        # Convert training set back to numpy array and list
        balanced_train_embeddings = balanced_train_df.drop(columns=['Label', 'Images']).to_numpy()
        balanced_train_labels = balanced_train_df['Label'].tolist()
        balanced_train_images = balanced_train_df['Images'].tolist()

        # Convert test set back to numpy array and list
        if 'Test' in test_df:
            test_embeddings = test_df.drop(columns=['Label', 'Images', 'Test']).to_numpy()
        else:
            test_embeddings = test_df.drop(columns=['Label', 'Images']).to_numpy()

        test_labels = test_df['Label'].tolist()
        test_images = test_df['Images'].tolist()

    # Print the number of frames in each class after balancing
    if specified_classes or max_class_size:
        balanced_label_counts = Counter(balanced_train_labels)
        print("\nNumber of training frames in each class after balancing:")
        print(balanced_label_counts)

    test_label_counts = test_df['Label'].value_counts()
    # print("\nNumber of testing frames in each class:")
    print(test_label_counts)

    return balanced_train_embeddings, balanced_train_labels, balanced_train_images, test_embeddings, test_labels, test_images

def multiclass_merge_and_filter_bouts(multiclass_vector, bout_threshold, proximity_threshold):
    """Merge nearby bouts and drop short ones, independently for each label.

    Parameters
    ----------
    multiclass_vector : numpy.ndarray
        1-D integer label vector; 0 is treated as background/no label.
    bout_threshold : int
        Minimum bout length (in samples) to keep after merging.
    proximity_threshold : int
        Two bouts of the same label separated by a gap of at most this many
        samples are merged into one.

    Returns
    -------
    numpy.ndarray
        Vector of the same shape with per-label merged/filtered bouts. If
        bouts of different labels come to overlap, labels processed later
        (np.unique order) overwrite earlier ones.
    """
    # Get the unique labels in the multiclass vector (excluding zero, assuming zero is the background/no label)
    unique_labels = np.unique(multiclass_vector)
    unique_labels = unique_labels[unique_labels != 0]

    # Initialize a vector to store the merged and filtered multiclass vector
    merged_vector = np.zeros_like(multiclass_vector)

    for label in unique_labels:
        # Create a binary vector for the current label
        binary_vector = (multiclass_vector == label)

        # Find the start and end indices of all sequences of 1's for this label.
        # starts[i] is the first index of bout i; ends[i] is its last index
        # (the padded diff makes both detectable at the array edges).
        starts = np.where(np.diff(np.concatenate(([0], binary_vector))) == 1)[0]
        ends = np.where(np.diff(np.concatenate((binary_vector, [0]))) == -1)[0]

        # Step 1: Merge close short bouts
        i = 0
        while i < len(starts) - 1:
            # Check if the gap between the end of the current bout and the start of the next bout
            # is within the proximity threshold
            if starts[i + 1] - ends[i] <= proximity_threshold:
                # Merge the two bouts by setting all elements between the start of the first
                # and the end of the second bout to 1
                binary_vector[ends[i]:starts[i + 1]] = 1
                # Remove the next bout from consideration: drop the next
                # bout's start and the current bout's end, leaving one
                # combined (start, end) pair at index i.
                starts = np.delete(starts, i + 1)
                ends = np.delete(ends, i)
            else:
                i += 1

        # Update the starts and ends after merging
        starts = np.where(np.diff(np.concatenate(([0], binary_vector))) == 1)[0]
        ends = np.where(np.diff(np.concatenate((binary_vector, [0]))) == -1)[0]

        # Step 2: Remove standalone short bouts
        for i in range(len(starts)):
            # Check the length of the bout
            length_of_bout = ends[i] - starts[i] + 1

            # If the length is less than the threshold, set those elements to 0
            if length_of_bout < bout_threshold:
                binary_vector[starts[i]:ends[i] + 1] = 0

        # Combine the binary vector with the merged_vector, ensuring only the current label is set
        merged_vector[binary_vector] = label

    # Return the filtered multiclass vector
    return merged_vector

def get_unique_labels(label_list: list[str]):
    """Return the distinct individual labels found in a list of
    '||'-separated multi-label strings (order not guaranteed)."""
    return list({part for label in label_list for part in label.split('||')})

def get_train_test_split(train_embeds, numerical_labels, test_size=0.05, random_state=42):
    """Thin wrapper around sklearn's train_test_split with project defaults
    (5% held out, fixed seed). Returns X_train, X_test, y_train, y_test."""
    return train_test_split(train_embeds, numerical_labels, test_size=test_size, random_state=random_state)

def train_model(X_train, y_train, random_state=42):
    """Fit an RBF-kernel SVM (with probability estimates enabled) on the
    given embeddings/labels and return the fitted classifier."""
    # Train SVM Classifier
    svm_clf = SVC(kernel='rbf', random_state=random_state, probability=True)
    svm_clf.fit(X_train, y_train)
    return svm_clf

def pickle_model(model):
    """Serialize `model` into an in-memory buffer.

    Returns
    -------
    io.BytesIO
        Buffer containing the pickled model, rewound to position 0 so
        callers can read it immediately. (The original returned the buffer
        at EOF, so a subsequent read() yielded nothing.)
    """
    pickled = io.BytesIO()
    pickle.dump(model, pickled)
    pickled.seek(0)  # rewind so consumers can read without seeking first
    return pickled

def get_seq_io_reader(uploaded_file):
    """Persist an uploaded .seq file to a temp file and open a seq reader on it.

    NOTE(review): duplicates the seq branch of `get_io_reader` in this
    module — consider consolidating.
    """
    assert uploaded_file.name[-3:]=='seq', 'Not a seq file'
    # delete=False: the reader keeps reading from the temp file after this
    # block exits, so it must not be removed here.
    with NamedTemporaryFile(suffix="seq", delete=False) as temp:
        temp.write(uploaded_file.getvalue())
        sr = seqIo_reader(temp.name)
    return sr

def seq_to_arr(sr):
    """Read every frame from an open seq reader and stack them into one
    numpy array (frame timestamps are discarded)."""
    total = sr.header['numFrames']
    return np.array([sr.getFrame(idx)[0] for idx in range(total)])

def get_2d_embedding(embeddings: pd.DataFrame):
    """Project high-dimensional embeddings to 2-D with t-SNE (fixed seed,
    perplexity 50, 4 parallel jobs) for visualization. Returns an (n, 2)
    numpy array, one row per input row."""
    tsne = TSNE(n_jobs=4, n_components=2, random_state=42, perplexity=50)
    embedding_2d = tsne.fit_transform(np.array(embeddings))
    return embedding_2d