import os import io import pickle import copy from collections import Counter from pathlib import Path from tempfile import NamedTemporaryFile import regex as re import numpy as np import pandas as pd from sklearn.manifold import TSNE from sklearn.svm import SVC from sklearn.model_selection import train_test_split from sklearn.metrics import accuracy_score, classification_report import torch from tqdm import tqdm from PIL import Image from transformers import AutoProcessor, AutoModel import streamlit as st from .data_loading import load_multiple_annotations, load_multiple_annotations_io from .data_processing import generate_label_array from .seqIo import seqIo_reader from .mp4Io import mp4Io_reader SLIP_MODEL_ID = "google/siglip-so400m-patch14-384" CLIP_MODEL_ID = "openai/clip-vit-base-patch32" def create_annot_fname_dict(annot_fnames: list[str])-> dict: fs = re.compile(r'.*(_\d+)$') unique_files = set() for file in annot_fnames: file_name = os.fsdecode(file) base_name, _ = os.path.splitext(file_name) if fs.match(base_name): ind = len(fs.match(base_name).group(1)) unique_files.add(base_name[:-ind]) else: unique_files.add(base_name) annot_fname_dict = {} for unique_file in unique_files: annot_fname_dict.update({unique_file: [file for file in annot_fnames if unique_file in file]}) return annot_fname_dict def create_annot_fname_dict_io(annot_fnames: list[str], annot_files: list)-> dict: annot_file_dict = {} for file in annot_files: annot_file_dict.update({file.name : file}) fs = re.compile(r'.*(_\d+)$') unique_files = set() for file in annot_fnames: file_name = os.fsdecode(file) base_name, _ = os.path.splitext(file_name) if fs.match(base_name): ind = len(fs.match(base_name).group(1)) unique_files.add(base_name[:-ind]) else: unique_files.add(base_name) annot_fname_dict = {} for unique_file in unique_files: annot_list = [file for file in annot_fnames if unique_file in file] annot_list.sort() annot_file_list = [annot_file_dict[annot_file_name] for annot_file_name in annot_list] annot_fname_dict.update({unique_file: annot_file_list}) return annot_fname_dict def get_io_reader(uploaded_file): assert uploaded_file.name[-3:]=='seq', 'Not a seq file' with NamedTemporaryFile(suffix="seq", delete=False) as temp: temp.write(uploaded_file.getvalue()) sr = seqIo_reader(temp.name) return sr def load_slip_model(device): return AutoModel.from_pretrained(SLIP_MODEL_ID).to(device) def load_slip_preprocessor(): return AutoProcessor.from_pretrained(SLIP_MODEL_ID) def load_clip_model(device): return AutoModel.from_pretrained(CLIP_MODEL_ID).to(device) def load_clip_preprocessor(): return AutoProcessor.from_pretrained(CLIP_MODEL_ID) def encode_image(image, device, model, processor): with torch.no_grad(): #convert_models_to_fp32(model) inputs = processor(images=image, return_tensors="pt").to(device) image_features = model.get_image_features(**inputs) return image_features.cpu().numpy().flatten() def generate_embeddings_stream(fnames : list[str], model = 'SLIP', downsample_rate = 4, save_csv = False)-> tuple[list, list, list]: # set up model and device device = "cuda" if torch.cuda.is_available() else "cpu" os.environ['CUDA_VISIBLE_DEVICES'] = '0' if model == 'SLIP': embed_model = load_slip_model(device) processor = load_slip_preprocessor() elif model == 'CLIP': embed_model = load_clip_model(device) processor = load_clip_preprocessor() all_video_embeddings = [] all_video_frames = [] for fname in fnames: # read in file is_seq = False if fname[-3:] == 'seq': is_seq = True if is_seq: sr = seqIo_reader(fname) else: sr = mp4Io_reader(fname) N = sr.header['numFrames'] # set up embeddings and frame arrays embeddings = [] frames = list(range(N))[::downsample_rate] print(frames) # create progress bar i = 0 pbar_text = lambda i: f'Creating embeddings for {fname}. {i}/{len(frames)} frames.' pbar = st.progress(0, text=pbar_text(0)) # convert each frame to embeddings for f in tqdm(frames): img, _ = sr.getFrame(f) img_arr = np.array(img) if is_seq: img_rgb = Image.fromarray(img_arr, 'L').convert('RGB') else: img_rgb = Image.fromarray(img_arr).convert('RGB') embeddings.append(encode_image(img_rgb, device, embed_model, processor)) # update progress bar i += 1 pbar.progress(i/len(frames), pbar_text(i)) # save csv of single file if save_csv: df = pd.DataFrame(embeddings) df['Frame'] = frames # save csv basename = Path(fname).stem df.to_csv(f'{basename}_embeddings_downsample_{downsample_rate}.csv', index=False) all_video_embeddings.append(np.array(embeddings)) all_video_frames.append(frames) return all_video_embeddings, all_video_frames def get_io_reader(uploaded_file): if uploaded_file.name[-3:]=='seq': with NamedTemporaryFile(suffix="seq", delete=False) as temp: temp.write(uploaded_file.getvalue()) sr = seqIo_reader(temp.name) else: with NamedTemporaryFile(suffix="mp4", delete=False) as temp: temp.write(uploaded_file.getvalue()) sr = mp4Io_reader(temp.name) return sr def generate_embeddings_stream_io(uploaded_files : list, model = 'SLIP', downsample_rate = 4, save_csv = False)-> tuple[list, list, list]: # set up model and device device = "cuda" if torch.cuda.is_available() else "cpu" os.environ['CUDA_VISIBLE_DEVICES'] = '0' with st.spinner('Loading multimodal model...'): if model == 'SLIP': embed_model = load_slip_model(device) processor = load_slip_preprocessor() elif model == 'CLIP': embed_model = load_clip_model(device) processor = load_clip_preprocessor() all_video_embeddings = [] all_video_frames = [] for file in uploaded_files: is_seq = False if file.name[-3:] == 'seq': is_seq = True # read in file sr = get_io_reader(file) N = sr.header['numFrames'] # set up embeddings and frame arrays embeddings = [] frames = list(range(N))[::downsample_rate] print(frames) # create progress bar i = 0 pbar_text = lambda i: f'Creating embeddings for {file.name}. {i}/{len(frames)} frames.' pbar = st.progress(0, text=pbar_text(0)) # convert each frame to embeddings for f in tqdm(frames): img, _ = sr.getFrame(f) img_arr = np.array(img) if is_seq: img_rgb = Image.fromarray(img_arr, 'L').convert('RGB') else: img_rgb = Image.fromarray(img_arr).convert('RGB') embeddings.append(encode_image(img_rgb, device, embed_model, processor)) # update progress bar i += 1 pbar.progress(i/len(frames), pbar_text(i)) # save csv of single file if save_csv: df = pd.DataFrame(embeddings) df['Frame'] = frames # save csv df.to_csv(f'embeddings_downsample_{downsample_rate}_{N}_frames.csv', index=False) all_video_embeddings.append(np.array(embeddings)) all_video_frames.append(frames) return all_video_embeddings, all_video_frames def create_embeddings_csv(out: str, fnames: list[str], embeddings: list[np.ndarray], frames: list[list[int]], annotations: list[list[str]], test_fnames: None | list[str], views: None | list[str], conditions: None | list[str], downsample_rate = 4, filesystem = None): """ Creates a .csv file containing all of the generated embeddings and provived information. Parameters: ----------- out : str The name of the resulting file. fnames : list[str] Video sources for each of the embedding arrays. embeddings : np.ndarray The generated embeddings from the images. downsample_rate : int The downsample_rate used for generating the embeddings. """ assert len(fnames) == len(embeddings) assert len(embeddings) == len(frames) all_embeddings = np.vstack(embeddings) df = pd.DataFrame(all_embeddings) labels = [] for i, annot_fnames in enumerate(annotations): _, ext = os.path.splitext(annot_fnames[0]) if ext == '.annot': annot, _, _, sr = load_multiple_annotations(annot_fnames, filesystem=filesystem) annot_labels = generate_label_array(annot, downsample_rate, len(frames[i])) elif ext == '.csv': if not filesystem: annot_df = pd.read_csv(annot_fnames[0], header=None) else: with filesystem.open(annot_fnames[0], 'r') as csv_file: annot_df = pd.read_csv(csv_file, header=None) annot_labels = annot_df[0].to_list()[::downsample_rate] assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure that the passed in csv file has no header." else: raise ValueError(f'Incompatible file for annotations used. Got a file of type "{ext}".') assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure you have passed in the correct files." print(annot_labels) labels.append(annot_labels) all_labels = np.hstack(labels) print(len(all_labels)) df['Label'] = all_labels all_frames = np.hstack(frames) df['Frame'] = all_frames sources = [[fname for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)] all_sources = np.hstack(sources) df['Source'] = all_sources if test_fnames: t_split = lambda x: True if x in test_fnames else False test = [[t_split(fname) for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)] else: test = [[True for _ in range(len(frames[i]))] for i, _ in enumerate(fnames)] all_test = np.hstack(test) df['Test'] = all_test if views: view = [[views[i] for _ in range(len(frames[i]))] for i in range(len(fnames))] else: view = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))] all_view = np.hstack(view) df['View'] = all_view if conditions: condition = [[conditions[i] for _ in range(len(frames[i]))] for i in range(len(fnames))] else: condition = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))] all_condition = np.hstack(condition) df['Condition'] = all_condition return df def create_embeddings_csv_io(out: str, fnames: list[str], embeddings: list[np.ndarray], frames: list[list[int]], annotations: list, test_fnames: None | list[str], views: None | list[str], conditions: None | list[str], downsample_rate = 4): """ Creates a .csv file containing all of the generated embeddings and provived information. Parameters: ----------- out : str The name of the resulting file. fnames : list[str] Video sources for each of the embedding arrays. embeddings : np.ndarray The generated embeddings from the images. downsample_rate : int The downsample_rate used for generating the embeddings. """ assert len(fnames) == len(embeddings) assert len(embeddings) == len(frames) all_embeddings = np.vstack(embeddings) df = pd.DataFrame(all_embeddings) labels = [] for i, uploaded_annots in enumerate(annotations): print(i) _, ext = os.path.splitext(uploaded_annots[0].name) if ext == '.annot': annot, _, _, sr = load_multiple_annotations_io(uploaded_annots) annot_labels = generate_label_array(annot, downsample_rate, len(frames[i])) elif ext == '.csv': annot_df = pd.read_csv(uploaded_annots[0], header=None) annot_labels = annot_df[0].to_list()[::downsample_rate] assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure that the passed in csv file has no header." else: raise ValueError(f'Incompatible file for annotations used. Got a file of type "{ext}".') assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure you have passed in the correct files." print(annot_labels) labels.append(annot_labels) all_labels = np.hstack(labels) print(len(all_labels)) df['Label'] = all_labels all_frames = np.hstack(frames) df['Frame'] = all_frames sources = [[fname for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)] all_sources = np.hstack(sources) df['Source'] = all_sources if test_fnames: t_split = lambda x: True if x in test_fnames else False test = [[t_split(fname) for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)] else: test = [[True for _ in range(len(frames[i]))] for i, _ in enumerate(fnames)] all_test = np.hstack(test) df['Test'] = all_test if views: view = [[views[i] for _ in range(len(frames[i]))] for i in range(len(fnames))] else: view = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))] all_view = np.hstack(view) df['View'] = all_view if conditions: condition = [[conditions[i] for _ in range(len(frames[i]))] for i in range(len(fnames))] else: condition = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))] all_condition = np.hstack(condition) df['Condition'] = all_condition return df def process_dataset_in_mem(embeddings_df: pd.DataFrame, specified_classes=None, classes_to_remove=None, max_class_size=None, animal_state=None, view=None, shuffle_data=False, test_videos=None): """ Processes output generated from embeddings paired with images and behavior labels. Parameters: ----------- csv_path : str Path to the file containing the original data. This should contain embeddings, a column named `'Label'` and a column named `'Images'`. specified_classes : None | list[str] An optional input. Defines labels which should be kept as is in the `'Label'` column and which should be changed to a default `other` label. classes_to_remove : None | list[str] An optional input. Drops rows from the dataframe which contain a label in the list. max_class_size : None | int An optional input. Determines the maximum amount of rows a single label can appear in for each unique label in the `'Label'` column. animal_state : None | str An optional input. Drops rows from the dataframe which do not contain a match for `animal_state` in the text field within the `'Images'` column. view : None | str An optional input. Drops rows from the dataframe which do not contain a match for `view` in the text field within the `'Images'` column. shuffle_data : bool Determines wether the dataframe should have its rows shuffled. test_videos : None | list[str] An optional input. Determines what rows should be in the `test` dataframe, and which should be in the `train` dataframe. It drops rows from the respective dataframe by keeping or dropping rows which do not contain a match for a `str` in `test_videos` in the text field within the `'Images'` column, respectively. Returns: -------- balanced_train_embeddings : pandas.DataFrame A processed dataframe whose rows contain the embeddings for each of the images at the corresponding index within `balanced_train_images`. balanced_train_labels : list[str] A list of labels for each of the images at the corresponing index within `balanced_train_images`. balanced_train_images: list[str] A list of paths to images with each image at an index corresponding to a label with the same index in `balanced_train_labels` and the same row index within `balanced_train_embeddings`. test_embeddings : pandas.DataFrame A processed dataframe whose rows contain the embeddings for each of the images at the corresponding index within `test_images`. test_labels : list[str] A list of labels for each of the images at the corresponing index within `test_images`. test_images : list[str] A list of paths to images with each image at an index corresponding to a label with the same index in `test_labels` and the same row index within `test_embeddings`. """ # Convert embeddings, labels, and images to a DataFrame for easy manipulation df = copy.deepcopy(embeddings_df) df_keys = [str(x) for x in df.keys()] #Filter by fed or fasted if 'Condition' in df_keys and animal_state: df = df[df['Condition'].str.contains(animal_state, na=False)] if 'View' in df_keys and view: df = df[df['View'].str.contains(view, na=False)] # Extract unique video names excluding the frame number #unique_video_names = df['Images'].apply(lambda x: '_'.join(x.split('_')[:-1])).unique() #print("\nUnique video names:\n", unique_video_names) if classes_to_remove: df = df[~df['Label'].str.contains('|'.join(classes_to_remove), na=False)] elif classes_to_remove and 'all' in classes_to_remove: df = df[df['Label'].str.contains('|'.join(classes_to_remove), na=False)] # Further filter to include only specified_classes if specified_classes: single_match = lambda x: list(set(x.split('||')) & set(specified_classes))[0] df['Label'] = df['Label'].apply(lambda x: single_match(x) if not set(x.split('||')).isdisjoint(specified_classes) else 'other') specified_classes.append('other') # Separate the DataFrame into test and training sets based on test_videos if 'Test' in df_keys and test_videos: test_df = df[df['Test']] train_df = df[~df['Test']] elif test_videos: test_df = df[df['Images'].str.contains('|'.join(test_videos), na=False)] train_df = df[~df['Images'].str.contains('|'.join(test_videos), na=False)] else: test_df = pd.DataFrame(columns=df.columns) train_df = df # Print the number of frames in each class before balancing label_counts = train_df['Label'].value_counts() print("\nNumber of training frames in each class before balancing:") print(label_counts) if max_class_size: balanced_train_df = pd.concat([ group.sample(n=min(len(group), max_class_size), random_state=1) for label, group in train_df.groupby('Label') ]) else: balanced_train_df = train_df # Shuffle the training DataFrame if shuffle_data: balanced_train_df = balanced_train_df.sample(frac=1).reset_index(drop=True) # Convert training set back to numpy array and list if not "Images" in df_keys: balanced_train_embeddings = balanced_train_df.drop(columns=['Label', 'Frame', 'Source', 'Test','View','Condition']).to_numpy() balanced_train_labels = balanced_train_df['Label'].tolist() balanced_train_images = balanced_train_df['Frame'].tolist() # Convert test set back to numpy array and list test_embeddings = test_df.drop(columns=['Label', 'Frame', 'Source', 'Test','View','Condition']).to_numpy() test_labels = test_df['Label'].tolist() test_images = test_df['Frame'].tolist() else: # Convert training set back to numpy array and list balanced_train_embeddings = balanced_train_df.drop(columns=['Label', 'Images']).to_numpy() balanced_train_labels = balanced_train_df['Label'].tolist() balanced_train_images = balanced_train_df['Images'].tolist() # Convert test set back to numpy array and list if 'Test' in test_df: test_embeddings = test_df.drop(columns=['Label', 'Images', 'Test']).to_numpy() else: test_embeddings = test_df.drop(columns=['Label', 'Images']).to_numpy() test_labels = test_df['Label'].tolist() test_images = test_df['Images'].tolist() # Print the number of frames in each class after balancing if specified_classes or max_class_size: balanced_label_counts = Counter(balanced_train_labels) print("\nNumber of training frames in each class after balancing:") print(balanced_label_counts) test_label_counts = test_df['Label'].value_counts() # print("\nNumber of testing frames in each class:") print(test_label_counts) return balanced_train_embeddings, balanced_train_labels, balanced_train_images, test_embeddings, test_labels, test_images def multiclass_merge_and_filter_bouts(multiclass_vector, bout_threshold, proximity_threshold): # Get the unique labels in the multiclass vector (excluding zero, assuming zero is the background/no label) unique_labels = np.unique(multiclass_vector) unique_labels = unique_labels[unique_labels != 0] # Initialize a vector to store the merged and filtered multiclass vector merged_vector = np.zeros_like(multiclass_vector) for label in unique_labels: # Create a binary vector for the current label binary_vector = (multiclass_vector == label) # Find the start and end indices of all sequences of 1's for this label starts = np.where(np.diff(np.concatenate(([0], binary_vector))) == 1)[0] ends = np.where(np.diff(np.concatenate((binary_vector, [0]))) == -1)[0] # Step 1: Merge close short bouts i = 0 while i < len(starts) - 1: # Check if the gap between the end of the current bout and the start of the next bout # is within the proximity threshold if starts[i + 1] - ends[i] <= proximity_threshold: # Merge the two bouts by setting all elements between the start of the first # and the end of the second bout to 1 binary_vector[ends[i]:starts[i + 1]] = 1 # Remove the next bout from consideration starts = np.delete(starts, i + 1) ends = np.delete(ends, i) else: i += 1 # Update the starts and ends after merging starts = np.where(np.diff(np.concatenate(([0], binary_vector))) == 1)[0] ends = np.where(np.diff(np.concatenate((binary_vector, [0]))) == -1)[0] # Step 2: Remove standalone short bouts for i in range(len(starts)): # Check the length of the bout length_of_bout = ends[i] - starts[i] + 1 # If the length is less than the threshold, set those elements to 0 if length_of_bout < bout_threshold: binary_vector[starts[i]:ends[i] + 1] = 0 # Combine the binary vector with the merged_vector, ensuring only the current label is set merged_vector[binary_vector] = label # Return the filtered multiclass vector return merged_vector def get_unique_labels(label_list: list[str]): label_set = set() for label in label_list: individual_labels = label.split('||') for individual_label in individual_labels: label_set.add(individual_label) return list(label_set) def get_train_test_split(train_embeds, numerical_labels, test_size=0.05, random_state=42): return train_test_split(train_embeds, numerical_labels, test_size=test_size, random_state=random_state) def train_model(X_train, y_train, random_state=42): # Train SVM Classifier svm_clf = SVC(kernel='rbf', random_state=random_state, probability=True) svm_clf.fit(X_train, y_train) return svm_clf def pickle_model(model): pickled = io.BytesIO() pickle.dump(model, pickled) return pickled def get_seq_io_reader(uploaded_file): assert uploaded_file.name[-3:]=='seq', 'Not a seq file' with NamedTemporaryFile(suffix="seq", delete=False) as temp: temp.write(uploaded_file.getvalue()) sr = seqIo_reader(temp.name) return sr def seq_to_arr(sr): N = sr.header['numFrames'] images = [] for f in range(N): I, ts = sr.getFrame(f) images.append(I) return np.array(images) def get_2d_embedding(embeddings: pd.DataFrame): tsne = TSNE(n_jobs=4, n_components=2, random_state=42, perplexity=50) embedding_2d = tsne.fit_transform(np.array(embeddings)) return embedding_2d