batik / utils /utils.py
ncoria's picture
add spinner to load slip/clip
1621356 verified
import os
import io
import pickle
import copy
from collections import Counter
from pathlib import Path
from tempfile import NamedTemporaryFile
import regex as re
import numpy as np
import pandas as pd
from sklearn.manifold import TSNE
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import torch
from tqdm import tqdm
from PIL import Image
from transformers import AutoProcessor, AutoModel
import streamlit as st
from .data_loading import load_multiple_annotations, load_multiple_annotations_io
from .data_processing import generate_label_array
from .seqIo import seqIo_reader
from .mp4Io import mp4Io_reader
# Checkpoint identifiers passed to `AutoModel.from_pretrained` and
# `AutoProcessor.from_pretrained` by the load_* helpers in this module.
# The 'SLIP' option of the embedding generators uses SLIP_MODEL_ID and the
# 'CLIP' option uses CLIP_MODEL_ID.
SLIP_MODEL_ID = "google/siglip-so400m-patch14-384"
CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
def create_annot_fname_dict(annot_fnames: list[str]) -> dict:
    """
    Groups annotation file names that belong to the same recording.

    The grouping key of a file is its name without the extension and without a
    trailing ``_<digits>`` part, so ``vid_1.annot`` and ``vid_2.annot`` both
    end up under ``vid``.

    Parameters:
    -----------
    annot_fnames : list[str]
        Annotation file names (or paths).

    Returns:
    --------
    annot_fname_dict : dict
        Maps each grouping key to the list of entries of `annot_fnames` that
        produced it, in their original order.
    """
    numbered_suffix = re.compile(r'.*(_\d+)$')
    annot_fname_dict = {}
    for file in annot_fnames:
        base_name, _ = os.path.splitext(os.fsdecode(file))
        match = numbered_suffix.match(base_name)
        key = base_name[:-len(match.group(1))] if match else base_name
        # Group by the computed key itself. A substring test such as
        # `key in file` would also put 'vid10_1.annot' into the 'vid1' group.
        annot_fname_dict.setdefault(key, []).append(file)
    return annot_fname_dict
def create_annot_fname_dict_io(annot_fnames: list[str], annot_files: list) -> dict:
    """
    Groups uploaded annotation files that belong to the same recording.

    The grouping key of a file is its name without the extension and without a
    trailing ``_<digits>`` part, so ``vid_1.annot`` and ``vid_2.annot`` both
    end up under ``vid``.

    Parameters:
    -----------
    annot_fnames : list[str]
        Names of the annotation files to group.
    annot_files : list
        File objects exposing a ``name`` attribute. Every entry of
        `annot_fnames` must be the ``name`` of one of them, otherwise a
        ``KeyError`` is raised.

    Returns:
    --------
    annot_fname_dict : dict
        Maps each grouping key to the matching file objects, ordered by file
        name (plain string sort).
    """
    files_by_name = {file.name: file for file in annot_files}
    numbered_suffix = re.compile(r'.*(_\d+)$')

    names_by_key = {}
    for fname in annot_fnames:
        base_name, _ = os.path.splitext(os.fsdecode(fname))
        match = numbered_suffix.match(base_name)
        key = base_name[:-len(match.group(1))] if match else base_name
        # Group by the computed key itself. A substring test such as
        # `key in fname` would also put 'vid10_1.annot' into the 'vid1' group.
        names_by_key.setdefault(key, []).append(fname)

    # NOTE(review): the sort is lexicographic, so '_10' sorts before '_2'.
    return {
        key: [files_by_name[name] for name in sorted(names)]
        for key, names in names_by_key.items()
    }
def get_io_reader(uploaded_file):
    """
    Copies an uploaded seq file to a temporary file and opens it with `seqIo_reader`.

    NOTE(review): this module defines `get_io_reader` a second time further
    down; at import time that later definition replaces this one.

    Parameters:
    -----------
    uploaded_file
        Object exposing ``name`` and ``getvalue()``; ``name`` must end in
        ``seq`` (checked with an ``assert``).

    Returns:
    --------
    sr
        The `seqIo_reader` opened on the temporary copy.
    """
    assert uploaded_file.name[-3:]=='seq', 'Not a seq file'
    with NamedTemporaryFile(suffix="seq", delete=False) as temp:
        temp.write(uploaded_file.getvalue())
    # Build the reader only after the `with` block has closed the file, so all
    # buffered bytes are on disk before the path is opened again.
    # delete=False keeps the file on disk; nothing in this function removes it.
    return seqIo_reader(temp.name)
def load_slip_model(device):
    """Load the pretrained model named by `SLIP_MODEL_ID` and move it to `device`."""
    slip_model = AutoModel.from_pretrained(SLIP_MODEL_ID)
    return slip_model.to(device)
def load_slip_preprocessor():
    """Load the pretrained processor named by `SLIP_MODEL_ID`."""
    slip_processor = AutoProcessor.from_pretrained(SLIP_MODEL_ID)
    return slip_processor
def load_clip_model(device):
    """Load the pretrained model named by `CLIP_MODEL_ID` and move it to `device`."""
    clip_model = AutoModel.from_pretrained(CLIP_MODEL_ID)
    return clip_model.to(device)
def load_clip_preprocessor():
    """Load the pretrained processor named by `CLIP_MODEL_ID`."""
    clip_processor = AutoProcessor.from_pretrained(CLIP_MODEL_ID)
    return clip_processor
def encode_image(image, device, model, processor):
    """
    Computes the image embedding of a single image as a flat numpy array.

    Parameters:
    -----------
    image
        The image handed to `processor` via its ``images`` argument.
    device
        Device the processed inputs are moved to before the forward pass.
    model
        Object exposing ``get_image_features(**inputs)``.
    processor
        Callable taking ``images=`` and ``return_tensors="pt"`` whose result
        supports ``.to(device)`` and ``**`` unpacking.

    Returns:
    --------
    numpy.ndarray
        The image features, moved to the CPU and flattened to one dimension.
    """
    # No gradients are needed for inference.
    with torch.no_grad():
        model_inputs = processor(images=image, return_tensors="pt")
        model_inputs = model_inputs.to(device)
        features = model.get_image_features(**model_inputs)
    features_on_cpu = features.cpu().numpy()
    return features_on_cpu.flatten()
def generate_embeddings_stream(fnames : list[str],
                               model = 'SLIP',
                               downsample_rate = 4,
                               save_csv = False)-> tuple[list, list]:
    """
    Generates image embeddings for every `downsample_rate`-th frame of each video.

    Parameters:
    -----------
    fnames : list[str]
        Paths of the videos. A path ending in ``seq`` is opened with
        `seqIo_reader`, anything else with `mp4Io_reader`.
    model : str
        Either ``'SLIP'`` or ``'CLIP'``; selects which model/processor pair is
        loaded.
    downsample_rate : int
        Step between the frame indices that get embedded.
    save_csv : bool
        When true, writes one ``<stem>_embeddings_downsample_<rate>.csv`` per
        video (relative path) holding the embeddings plus a ``Frame`` column.

    Returns:
    --------
    all_video_embeddings : list[numpy.ndarray]
        One array of embeddings per video.
    all_video_frames : list[list[int]]
        The frame indices that were embedded, one list per video.

    Raises:
    -------
    ValueError
        If `model` is neither ``'SLIP'`` nor ``'CLIP'``.
    """
    # Reject unknown model names up front; otherwise the model/processor
    # variables stay unbound and the failure only shows up at the first frame.
    if model not in ('SLIP', 'CLIP'):
        raise ValueError(f'Unknown model "{model}". Expected "SLIP" or "CLIP".')

    # set up model and device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    if model == 'SLIP':
        embed_model = load_slip_model(device)
        processor = load_slip_preprocessor()
    else:
        embed_model = load_clip_model(device)
        processor = load_clip_preprocessor()

    all_video_embeddings = []
    all_video_frames = []
    for fname in fnames:
        # read in file
        is_seq = fname[-3:] == 'seq'
        sr = seqIo_reader(fname) if is_seq else mp4Io_reader(fname)
        N = sr.header['numFrames']

        # set up embeddings and frame arrays
        embeddings = []
        frames = list(range(N))[::downsample_rate]

        # create progress bar
        def pbar_text(done, fname=fname, total=len(frames)):
            return f'Creating embeddings for {fname}. {done}/{total} frames.'
        pbar = st.progress(0, text=pbar_text(0))

        # convert each frame to embeddings
        for done, f in enumerate(tqdm(frames), start=1):
            img, _ = sr.getFrame(f)
            img_arr = np.array(img)
            if is_seq:
                # seq frames are converted as 8-bit grayscale ('L') images
                img_rgb = Image.fromarray(img_arr, 'L').convert('RGB')
            else:
                img_rgb = Image.fromarray(img_arr).convert('RGB')
            embeddings.append(encode_image(img_rgb, device, embed_model, processor))
            # update progress bar
            pbar.progress(done/len(frames), pbar_text(done))

        # save csv of single file
        if save_csv:
            df = pd.DataFrame(embeddings)
            df['Frame'] = frames
            basename = Path(fname).stem
            df.to_csv(f'{basename}_embeddings_downsample_{downsample_rate}.csv', index=False)

        all_video_embeddings.append(np.array(embeddings))
        all_video_frames.append(frames)
    return all_video_embeddings, all_video_frames
def get_io_reader(uploaded_file):
    """
    Copies an uploaded video to a temporary file and opens a frame reader on it.

    Parameters:
    -----------
    uploaded_file
        Object exposing ``name`` and ``getvalue()``. A name ending in ``seq``
        selects `seqIo_reader`; any other name selects `mp4Io_reader`.

    Returns:
    --------
    sr
        The reader opened on the temporary copy.
    """
    is_seq = uploaded_file.name[-3:]=='seq'
    with NamedTemporaryFile(suffix="seq" if is_seq else "mp4", delete=False) as temp:
        temp.write(uploaded_file.getvalue())
    # Build the reader only after the `with` block has closed the file, so all
    # buffered bytes are on disk before the path is opened again.
    # delete=False keeps the file on disk; nothing in this function removes it.
    if is_seq:
        return seqIo_reader(temp.name)
    return mp4Io_reader(temp.name)
def generate_embeddings_stream_io(uploaded_files : list,
                                  model = 'SLIP',
                                  downsample_rate = 4,
                                  save_csv = False)-> tuple[list, list]:
    """
    Generates image embeddings for every `downsample_rate`-th frame of each uploaded video.

    Parameters:
    -----------
    uploaded_files : list
        Uploaded file objects; each is opened through `get_io_reader`. A
        ``name`` ending in ``seq`` marks the frames as grayscale seq frames.
    model : str
        Either ``'SLIP'`` or ``'CLIP'``; selects which model/processor pair is
        loaded (a Streamlit spinner is shown while loading).
    downsample_rate : int
        Step between the frame indices that get embedded.
    save_csv : bool
        When true, writes ``embeddings_downsample_<rate>_<N>_frames.csv``
        (relative path) per video, where ``N`` is the video's frame count.
        The name depends only on the rate and ``N``, so two uploads with the
        same frame count write to the same file.

    Returns:
    --------
    all_video_embeddings : list[numpy.ndarray]
        One array of embeddings per uploaded video.
    all_video_frames : list[list[int]]
        The frame indices that were embedded, one list per video.

    Raises:
    -------
    ValueError
        If `model` is neither ``'SLIP'`` nor ``'CLIP'``.
    """
    # Reject unknown model names up front; otherwise the model/processor
    # variables stay unbound and the failure only shows up at the first frame.
    if model not in ('SLIP', 'CLIP'):
        raise ValueError(f'Unknown model "{model}". Expected "SLIP" or "CLIP".')

    # set up model and device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    with st.spinner('Loading multimodal model...'):
        if model == 'SLIP':
            embed_model = load_slip_model(device)
            processor = load_slip_preprocessor()
        else:
            embed_model = load_clip_model(device)
            processor = load_clip_preprocessor()

    all_video_embeddings = []
    all_video_frames = []
    for file in uploaded_files:
        is_seq = file.name[-3:] == 'seq'
        # read in file
        sr = get_io_reader(file)
        N = sr.header['numFrames']

        # set up embeddings and frame arrays
        embeddings = []
        frames = list(range(N))[::downsample_rate]

        # create progress bar
        def pbar_text(done, name=file.name, total=len(frames)):
            return f'Creating embeddings for {name}. {done}/{total} frames.'
        pbar = st.progress(0, text=pbar_text(0))

        # convert each frame to embeddings
        for done, f in enumerate(tqdm(frames), start=1):
            img, _ = sr.getFrame(f)
            img_arr = np.array(img)
            if is_seq:
                # seq frames are converted as 8-bit grayscale ('L') images
                img_rgb = Image.fromarray(img_arr, 'L').convert('RGB')
            else:
                img_rgb = Image.fromarray(img_arr).convert('RGB')
            embeddings.append(encode_image(img_rgb, device, embed_model, processor))
            # update progress bar
            pbar.progress(done/len(frames), pbar_text(done))

        # save csv of single file
        if save_csv:
            df = pd.DataFrame(embeddings)
            df['Frame'] = frames
            df.to_csv(f'embeddings_downsample_{downsample_rate}_{N}_frames.csv', index=False)

        all_video_embeddings.append(np.array(embeddings))
        all_video_frames.append(frames)
    return all_video_embeddings, all_video_frames
def create_embeddings_csv(out: str,
fnames: list[str],
embeddings: list[np.ndarray],
frames: list[list[int]],
annotations: list[list[str]],
test_fnames: None | list[str],
views: None | list[str],
conditions: None | list[str],
downsample_rate = 4,
filesystem = None):
"""
Creates a .csv file containing all of the generated embeddings and provived information.
Parameters:
-----------
out : str
The name of the resulting file.
fnames : list[str]
Video sources for each of the embedding arrays.
embeddings : np.ndarray
The generated embeddings from the images.
downsample_rate : int
The downsample_rate used for generating the embeddings.
"""
assert len(fnames) == len(embeddings)
assert len(embeddings) == len(frames)
all_embeddings = np.vstack(embeddings)
df = pd.DataFrame(all_embeddings)
labels = []
for i, annot_fnames in enumerate(annotations):
_, ext = os.path.splitext(annot_fnames[0])
if ext == '.annot':
annot, _, _, sr = load_multiple_annotations(annot_fnames, filesystem=filesystem)
annot_labels = generate_label_array(annot, downsample_rate, len(frames[i]))
elif ext == '.csv':
if not filesystem:
annot_df = pd.read_csv(annot_fnames[0], header=None)
else:
with filesystem.open(annot_fnames[0], 'r') as csv_file:
annot_df = pd.read_csv(csv_file, header=None)
annot_labels = annot_df[0].to_list()[::downsample_rate]
assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure that the passed in csv file has no header."
else:
raise ValueError(f'Incompatible file for annotations used. Got a file of type "{ext}".')
assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure you have passed in the correct files."
print(annot_labels)
labels.append(annot_labels)
all_labels = np.hstack(labels)
print(len(all_labels))
df['Label'] = all_labels
all_frames = np.hstack(frames)
df['Frame'] = all_frames
sources = [[fname for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)]
all_sources = np.hstack(sources)
df['Source'] = all_sources
if test_fnames:
t_split = lambda x: True if x in test_fnames else False
test = [[t_split(fname) for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)]
else:
test = [[True for _ in range(len(frames[i]))] for i, _ in enumerate(fnames)]
all_test = np.hstack(test)
df['Test'] = all_test
if views:
view = [[views[i] for _ in range(len(frames[i]))] for i in range(len(fnames))]
else:
view = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))]
all_view = np.hstack(view)
df['View'] = all_view
if conditions:
condition = [[conditions[i] for _ in range(len(frames[i]))] for i in range(len(fnames))]
else:
condition = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))]
all_condition = np.hstack(condition)
df['Condition'] = all_condition
return df
def create_embeddings_csv_io(out: str,
fnames: list[str],
embeddings: list[np.ndarray],
frames: list[list[int]],
annotations: list,
test_fnames: None | list[str],
views: None | list[str],
conditions: None | list[str],
downsample_rate = 4):
"""
Creates a .csv file containing all of the generated embeddings and provived information.
Parameters:
-----------
out : str
The name of the resulting file.
fnames : list[str]
Video sources for each of the embedding arrays.
embeddings : np.ndarray
The generated embeddings from the images.
downsample_rate : int
The downsample_rate used for generating the embeddings.
"""
assert len(fnames) == len(embeddings)
assert len(embeddings) == len(frames)
all_embeddings = np.vstack(embeddings)
df = pd.DataFrame(all_embeddings)
labels = []
for i, uploaded_annots in enumerate(annotations):
print(i)
_, ext = os.path.splitext(uploaded_annots[0].name)
if ext == '.annot':
annot, _, _, sr = load_multiple_annotations_io(uploaded_annots)
annot_labels = generate_label_array(annot, downsample_rate, len(frames[i]))
elif ext == '.csv':
annot_df = pd.read_csv(uploaded_annots[0], header=None)
annot_labels = annot_df[0].to_list()[::downsample_rate]
assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure that the passed in csv file has no header."
else:
raise ValueError(f'Incompatible file for annotations used. Got a file of type "{ext}".')
assert len(annot_labels) == len(frames[i]), "There is a mismatch between the number of frames and number of labels. Make sure you have passed in the correct files."
print(annot_labels)
labels.append(annot_labels)
all_labels = np.hstack(labels)
print(len(all_labels))
df['Label'] = all_labels
all_frames = np.hstack(frames)
df['Frame'] = all_frames
sources = [[fname for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)]
all_sources = np.hstack(sources)
df['Source'] = all_sources
if test_fnames:
t_split = lambda x: True if x in test_fnames else False
test = [[t_split(fname) for _ in range(len(frames[i]))] for i, fname in enumerate(fnames)]
else:
test = [[True for _ in range(len(frames[i]))] for i, _ in enumerate(fnames)]
all_test = np.hstack(test)
df['Test'] = all_test
if views:
view = [[views[i] for _ in range(len(frames[i]))] for i in range(len(fnames))]
else:
view = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))]
all_view = np.hstack(view)
df['View'] = all_view
if conditions:
condition = [[conditions[i] for _ in range(len(frames[i]))] for i in range(len(fnames))]
else:
condition = [[None for _ in range(len(frames[i]))] for i in range(len(fnames))]
all_condition = np.hstack(condition)
df['Condition'] = all_condition
return df
def process_dataset_in_mem(embeddings_df: pd.DataFrame,
                           specified_classes=None,
                           classes_to_remove=None,
                           max_class_size=None,
                           animal_state=None,
                           view=None,
                           shuffle_data=False,
                           test_videos=None):
    """
    Filters, relabels, splits and balances an in-memory embeddings dataframe.

    Two layouts are handled. If the dataframe has an ``'Images'`` column, that
    column identifies each row. Otherwise the dataframe is expected to carry
    the ``'Frame'``, ``'Source'``, ``'Test'``, ``'View'`` and ``'Condition'``
    columns, and ``'Frame'`` identifies each row. All remaining columns are
    treated as the embedding.

    Parameters:
    -----------
    embeddings_df : pandas.DataFrame
        The embeddings together with a ``'Label'`` column and the metadata
        columns described above. The dataframe itself is not modified.
    specified_classes : None | list[str]
        An optional input. Labels are split on ``'||'``; a row keeps the first
        entry of `specified_classes` found among its labels and every other
        row is relabelled ``'other'``. NOTE: ``'other'`` is appended to the
        list that was passed in.
    classes_to_remove : None | list[str]
        An optional input. Drops rows whose label matches any entry of the
        list (the entries are joined with ``'|'`` and used as a pattern).
    max_class_size : None | int
        An optional input. Caps the number of training rows per label by
        sampling with a fixed seed.
    animal_state : None | str
        An optional input. Keeps only rows whose ``'Condition'`` value matches
        `animal_state` (applied only if that column exists).
    view : None | str
        An optional input. Keeps only rows whose ``'View'`` value matches
        `view` (applied only if that column exists).
    shuffle_data : bool
        Determines whether the training rows are shuffled.
    test_videos : None | list[str]
        An optional input. When given and a ``'Test'`` column exists, that
        boolean column decides the split. When given without a ``'Test'``
        column, rows whose ``'Images'`` value matches any entry go to the test
        set. When omitted, the test set is empty.

    Returns:
    --------
    balanced_train_embeddings : numpy.ndarray
        Embeddings of the training rows.
    balanced_train_labels : list[str]
        Labels of the training rows.
    balanced_train_images : list
        ``'Images'`` (or ``'Frame'``) values of the training rows.
    test_embeddings : numpy.ndarray
        Embeddings of the test rows.
    test_labels : list[str]
        Labels of the test rows.
    test_images : list
        ``'Images'`` (or ``'Frame'``) values of the test rows.
    """
    # Work on a copy so the caller's dataframe is left untouched
    df = embeddings_df.copy()
    df_keys = [str(x) for x in df.keys()]

    # Filter by condition (e.g. fed or fasted) and by camera view
    if 'Condition' in df_keys and animal_state:
        df = df[df['Condition'].str.contains(animal_state, na=False)]
    if 'View' in df_keys and view:
        df = df[df['View'].str.contains(view, na=False)]

    # NOTE(review): the original code had an additional
    # `elif classes_to_remove and 'all' in classes_to_remove` branch that
    # could never run; it was dropped without changing behaviour.
    if classes_to_remove:
        df = df[~df['Label'].str.contains('|'.join(classes_to_remove), na=False)]

    # Further filter to include only specified_classes
    if specified_classes:
        wanted = list(specified_classes)

        def collapse(label):
            # Pick the first wanted class in the caller's order, so the result
            # does not depend on set iteration order when several match.
            parts = label.split('||')
            for cls in wanted:
                if cls in parts:
                    return cls
            return 'other'

        # assign() returns a new frame, avoiding writes into a filtered slice
        df = df.assign(Label=df['Label'].apply(collapse))
        # Kept for backward compatibility: callers may rely on 'other' being
        # added to the list they passed in.
        specified_classes.append('other')

    # Separate the DataFrame into test and training sets based on test_videos
    if 'Test' in df_keys and test_videos:
        is_test = df['Test']
        test_df = df[is_test]
        train_df = df[~is_test]
    elif test_videos:
        is_test = df['Images'].str.contains('|'.join(test_videos), na=False)
        test_df = df[is_test]
        train_df = df[~is_test]
    else:
        test_df = pd.DataFrame(columns=df.columns)
        train_df = df

    # Print the number of frames in each class before balancing
    label_counts = train_df['Label'].value_counts()
    print("\nNumber of training frames in each class before balancing:")
    print(label_counts)

    if max_class_size:
        balanced_train_df = pd.concat([
            group.sample(n=min(len(group), max_class_size), random_state=1)
            for label, group in train_df.groupby('Label')
        ])
    else:
        balanced_train_df = train_df

    # Shuffle the training DataFrame
    if shuffle_data:
        balanced_train_df = balanced_train_df.sample(frac=1).reset_index(drop=True)

    # Everything that is not metadata is the embedding
    if "Images" not in df_keys:
        id_column = 'Frame'
        metadata_columns = ['Label', 'Frame', 'Source', 'Test', 'View', 'Condition']
    else:
        id_column = 'Images'
        metadata_columns = ['Label', 'Images']
        # Drop 'Test' from BOTH splits; removing it only from the test split
        # would give train and test embeddings a different number of columns.
        if 'Test' in df.columns:
            metadata_columns.append('Test')

    # Convert training set back to numpy array and list
    balanced_train_embeddings = balanced_train_df.drop(columns=metadata_columns).to_numpy()
    balanced_train_labels = balanced_train_df['Label'].tolist()
    balanced_train_images = balanced_train_df[id_column].tolist()

    # Convert test set back to numpy array and list
    test_embeddings = test_df.drop(columns=metadata_columns).to_numpy()
    test_labels = test_df['Label'].tolist()
    test_images = test_df[id_column].tolist()

    # Print the number of frames in each class after balancing
    if specified_classes or max_class_size:
        balanced_label_counts = Counter(balanced_train_labels)
        print("\nNumber of training frames in each class after balancing:")
        print(balanced_label_counts)
    test_label_counts = test_df['Label'].value_counts()
    print(test_label_counts)

    return balanced_train_embeddings, balanced_train_labels, balanced_train_images, test_embeddings, test_labels, test_images
def multiclass_merge_and_filter_bouts(multiclass_vector, bout_threshold, proximity_threshold):
    """
    Merges nearby bouts of the same label and removes bouts that stay too short.

    For every non-zero label, two consecutive bouts are joined when the
    distance from the last index of the first to the first index of the second
    is at most `proximity_threshold`. After merging, bouts shorter than
    `bout_threshold` samples are cleared. Labels are written into the result
    in ascending order, so a later label overwrites an earlier one wherever
    their bouts overlap.

    Parameters:
    -----------
    multiclass_vector : numpy.ndarray
        One label per sample; zero is treated as background.
    bout_threshold : int
        Minimum length of a bout that is kept.
    proximity_threshold : int
        Maximum distance between two bouts that are merged.

    Returns:
    --------
    numpy.ndarray
        A new vector shaped like `multiclass_vector` with the cleaned labels.
    """
    def bout_edges(mask):
        # Pad with zeros on both sides so bouts touching either end are found.
        steps = np.diff(np.concatenate(([0], mask, [0])))
        first = np.flatnonzero(steps == 1)
        last = np.flatnonzero(steps == -1) - 1
        return first, last

    cleaned = np.zeros_like(multiclass_vector)
    present = np.unique(multiclass_vector)

    for label in present[present != 0]:
        mask = multiclass_vector == label

        # Step 1: close the gap between every pair of neighbouring bouts that are near enough
        first, last = bout_edges(mask)
        for prev_last, next_first in zip(last[:-1], first[1:]):
            if next_first - prev_last <= proximity_threshold:
                mask[prev_last:next_first] = True

        # Step 2: clear the bouts that are still too short after merging
        first, last = bout_edges(mask)
        for lo, hi in zip(first, last):
            if hi - lo + 1 < bout_threshold:
                mask[lo:hi + 1] = False

        cleaned[mask] = label

    return cleaned
def get_unique_labels(label_list: list[str]):
    """
    Collects every distinct individual label found in `label_list`.

    Each entry may hold several labels joined by ``'||'``; the entries are
    split on that separator and the distinct pieces are returned as a list
    (in no particular order).
    """
    distinct = {piece for label in label_list for piece in label.split('||')}
    return list(distinct)
def get_train_test_split(train_embeds, numerical_labels, test_size=0.05, random_state=42):
    """
    Splits embeddings and labels with `train_test_split`.

    `test_size` and `random_state` are forwarded unchanged; the return value
    is whatever `train_test_split` returns for the two inputs.
    """
    split_options = {'test_size': test_size, 'random_state': random_state}
    return train_test_split(train_embeds, numerical_labels, **split_options)
def train_model(X_train, y_train, random_state=42):
    """
    Fits an RBF-kernel `SVC` with probability estimates enabled.

    Returns the fitted classifier.
    """
    svc_options = {'kernel': 'rbf', 'random_state': random_state, 'probability': True}
    classifier = SVC(**svc_options)
    classifier.fit(X_train, y_train)
    return classifier
def pickle_model(model):
    """
    Pickles `model` into an in-memory buffer.

    Returns the `io.BytesIO` holding the pickle. The buffer is not rewound,
    so its position is at the end of the written data.
    """
    buffer = io.BytesIO()
    pickle.dump(model, file=buffer)
    return buffer
def get_seq_io_reader(uploaded_file):
    """
    Copies an uploaded seq file to a temporary file and opens it with `seqIo_reader`.

    Parameters:
    -----------
    uploaded_file
        Object exposing ``name`` and ``getvalue()``; ``name`` must end in
        ``seq`` (checked with an ``assert``).

    Returns:
    --------
    sr
        The `seqIo_reader` opened on the temporary copy.
    """
    assert uploaded_file.name[-3:]=='seq', 'Not a seq file'
    with NamedTemporaryFile(suffix="seq", delete=False) as temp:
        temp.write(uploaded_file.getvalue())
    # Build the reader only after the `with` block has closed the file, so all
    # buffered bytes are on disk before the path is opened again.
    # delete=False keeps the file on disk; nothing in this function removes it.
    return seqIo_reader(temp.name)
def seq_to_arr(sr):
    """
    Reads every frame of a reader into a single numpy array.

    `sr` must expose ``header['numFrames']`` and ``getFrame(index)`` returning
    an ``(image, timestamp)`` pair; only the images are kept.
    """
    frame_count = sr.header['numFrames']
    frames = [sr.getFrame(index)[0] for index in range(frame_count)]
    return np.array(frames)
def get_2d_embedding(embeddings: pd.DataFrame):
    """
    Projects the embeddings to two dimensions with t-SNE.

    Uses a fixed configuration (perplexity 50, random_state 42, 4 jobs) and
    returns the result of `TSNE.fit_transform` on the embeddings converted to
    a numpy array.
    """
    points = np.array(embeddings)
    reducer = TSNE(n_components=2, perplexity=50, random_state=42, n_jobs=4)
    return reducer.fit_transform(points)