"""Streamlit page: apply a trained frame classifier to video embeddings.

Flow, top to bottom:
  1. Generate embeddings from an uploaded video + annotation files, or load
     previously generated embeddings from a .csv.
  2. Load a pickled classifier.
  3. Let the user pick labels, run the classifier, smooth the predictions.
  4. Show accuracy, a classification report, a behavior raster, and offer the
     predicted annotations as a .csv download.
"""
import os
import pickle
from random import random
import streamlit as st
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import numpy as np
import pandas as pd
import torch
from utils.mp4Io import mp4Io_reader
from utils.seqIo import seqIo_reader
from PIL import Image
from pathlib import Path
from transformers import AutoProcessor, AutoModel
from tempfile import NamedTemporaryFile
from tqdm import tqdm
from sklearn.metrics import accuracy_score, classification_report
from utils.utils import create_embeddings_csv_io, process_dataset_in_mem, multiclass_merge_and_filter_bouts, generate_embeddings_stream_io

# Streamlit launch flag noted by the original author (raises the upload limit):
# --server.maxUploadSize 3000


def get_io_reader(uploaded_file):
    """Return a frame reader for an uploaded ``.seq`` or ``.mp4`` file.

    The upload is copied to a named temporary file because the readers take a
    filesystem path. The temporary file is intentionally kept on disk
    (``delete=False``) since the returned reader refers to it by path.

    Args:
        uploaded_file: Streamlit uploaded file (needs ``.name`` and
            ``.getvalue()``).

    Returns:
        A ``seqIo_reader`` when the name ends in ``seq``, otherwise an
        ``mp4Io_reader``.
    """
    is_seq = uploaded_file.name.endswith('seq')
    with NamedTemporaryFile(suffix='.seq' if is_seq else '.mp4', delete=False) as temp:
        temp.write(uploaded_file.getvalue())
    # Open the reader only after the ``with`` has closed the handle, so every
    # byte is flushed to disk before the file is read back by path.
    if is_seq:
        return seqIo_reader(temp.name)
    return mp4Io_reader(temp.name)


def get_unique_labels(label_list: list[str]):
    """Return the distinct individual labels found in ``label_list``.

    Each entry may hold several labels joined by ``'||'``; they are split and
    de-duplicated. The result is sorted so the UI shows a stable order.
    """
    label_set = {
        individual_label
        for label in label_list
        for individual_label in label.split('||')
    }
    return sorted(label_set)


def get_smoothed_predictions(svm_model, test_embeds, bout_threshold=30, proximity_threshold=2):
    """Predict per-frame classes and smooth them into bouts.

    Args:
        svm_model: Fitted classifier exposing ``predict``.
        test_embeds: Embeddings to classify.
        bout_threshold: Passed to ``multiclass_merge_and_filter_bouts``.
        proximity_threshold: Passed to ``multiclass_merge_and_filter_bouts``.

    Returns:
        The output of ``multiclass_merge_and_filter_bouts`` for the raw
        predictions.
    """
    # Only hard predictions are used; probabilities were previously computed
    # and discarded, which also failed for models without ``predict_proba``.
    test_pred = svm_model.predict(test_embeds)
    return multiclass_merge_and_filter_bouts(test_pred, bout_threshold, proximity_threshold)


def _render_results(predictions, test_labels):
    """Show accuracy, report, raster and download button for predictions."""
    # Convert labels to numerical values, with 'other' pinned to index 0.
    # NOTE(review): the remaining indices follow set iteration order, which
    # must agree with the encoding used when the classifier was trained --
    # confirm against the training code before changing this ordering.
    label_to_appear_first = 'other'
    unique_labels = set(test_labels)
    unique_labels.discard(label_to_appear_first)
    labels = [label_to_appear_first] + list(unique_labels)
    label_to_index = {label: idx for idx, label in enumerate(labels)}
    numerical_labels_test = np.array([label_to_index[label] for label in test_labels])
    print("Label Valence: ", label_to_index)

    predictions = np.asarray(predictions)
    if len(predictions) == 0:
        # If no predictions meet the threshold, report accuracy 0 and stop:
        # there is nothing to tabulate, plot or export.
        st.text(f"Eval Accuracy: {0}")
        st.warning("No predictions were produced; nothing to report.")
        return

    test_accuracy = accuracy_score(numerical_labels_test, predictions)
    # ``labels=`` keeps target_names aligned even when some class is absent
    # from both the ground truth and the predictions.
    report = classification_report(numerical_labels_test,
                                   predictions,
                                   labels=list(range(len(labels))),
                                   target_names=labels,
                                   output_dict=True,
                                   zero_division=0)
    report_df = pd.DataFrame(report).transpose()

    st.text(f"Eval Accuracy: {test_accuracy}")
    st.subheader("Classification Report:")
    st.dataframe(report_df)

    # create figure (behavior raster)
    fig, ax = plt.subplots()
    colors = ['white'] + [(random(), random(), random()) for _ in range(len(labels) - 1)]
    # Pin the colour range to the full set of class indices; otherwise imshow
    # rescales to the values present and colours/ticks stop matching labels.
    raster = ax.imshow(predictions.reshape((1, predictions.size)),
                       aspect='auto',
                       interpolation='nearest',
                       cmap=ListedColormap(colors),
                       vmin=0,
                       vmax=max(len(labels) - 1, 1))
    ax.set_yticklabels([])
    ax.set_xlabel('frames')

    # Place one tick at the centre of each colour band.
    cbar = fig.colorbar(raster)
    spacing = (len(labels) - 1) / len(labels)
    start = spacing / 2
    ticks = [start] + [start + spacing * i for i in range(1, len(labels))]
    cbar.set_ticks(ticks=ticks, labels=labels)

    st.pyplot(fig)
    # The script reruns on every interaction; close the figure so they do not
    # accumulate in pyplot's global registry.
    plt.close(fig)

    # save generated annotations
    annotations = [labels[x] for x in predictions]
    annotations_df = pd.DataFrame(annotations, columns=['label'])
    csv = annotations_df.to_csv(header=False).encode("utf-8")
    output_file_name = st.text_input("Output File Name:", "output")
    st.download_button("Download annotations as .csv", data=csv, file_name=f"{output_file_name}.csv")


# Initialise each piece of session state independently so that a rerun never
# resets one of them while another is still populated.
if "embeddings_df_apply" not in st.session_state:
    st.session_state.embeddings_df_apply = None
if "smoothed_predictions" not in st.session_state:
    st.session_state.smoothed_predictions = None
if "test_labels" not in st.session_state:
    st.session_state.test_labels = []

st.title('batik: frame classifier')

st.text("Upload files to apply trained classifier on.")

with st.form('embedding_generation_settings'):
    seq_file = st.file_uploader("Choose a video file", type=['seq', 'mp4'])
    annot_files = st.file_uploader("Choose an annotation File", type=['annot', 'csv'], accept_multiple_files=True)
    downsample_rate = st.number_input('Downsample Rate', value=4)
    submit_embed_settings = st.form_submit_button('Create Embeddings', type='secondary')

st.markdown("**(Optional)** Upload embeddings if not generating above.")
embeddings_csv = st.file_uploader("Choose a .csv File", type=['csv'])

# NOTE(review): with accept_multiple_files=True the uploader yields a list, so
# ``annot_files is not None`` presumably always holds -- confirm whether an
# empty annotation list should be rejected here.
if submit_embed_settings and seq_file is not None and annot_files is not None:
    video_embeddings, video_frames = generate_embeddings_stream_io([seq_file],
                                                                   "SLIP",
                                                                   downsample_rate,
                                                                   False)

    fnames = [seq_file.name]
    embeddings_df = create_embeddings_csv_io(out="file",
                                             fnames=fnames,
                                             embeddings=video_embeddings,
                                             frames=video_frames,
                                             annotations=[annot_files],
                                             test_fnames=None,
                                             views=None,
                                             conditions=None,
                                             downsample_rate=downsample_rate)
    st.session_state.embeddings_df_apply = embeddings_df

elif embeddings_csv is not None:
    embeddings_df = pd.read_csv(embeddings_csv)
    st.session_state.embeddings_df_apply = embeddings_df
elif st.session_state.embeddings_df_apply is None:
    # Only prompt when there are no embeddings from an earlier run either.
    st.text('Please upload file(s).')

st.divider()
st.markdown("Upload classifier model.")
pickled_file = st.file_uploader("Choose a .pkl File", type=['pkl'])
if pickled_file is not None:
    # SECURITY: unpickling executes arbitrary code embedded in the file; only
    # load classifier files from a trusted source.
    # Loading straight from the uploaded bytes avoids writing a temporary
    # file on every rerun that was never cleaned up.
    svm_clf = pickle.loads(pickled_file.getvalue())
else:
    svm_clf = None

st.divider()

if st.session_state.embeddings_df_apply is not None and svm_clf is not None:
    st.subheader("specify dataset labels")
    label_list = st.session_state.embeddings_df_apply['Label'].to_list()
    unique_label_list = get_unique_labels(label_list)

    with st.form('apply_model_settings'):
        st.text("Select label(s):")
        specified_classes = st.multiselect("Label(s) included:", options=unique_label_list)
        apply_model = st.form_submit_button("Apply Model")

    if apply_model:
        apply_df = st.session_state.embeddings_df_apply
        if 'Test' not in apply_df:
            # No train/test split supplied: treat every row as test data.
            print(f'shape of df: {apply_df.shape[0]}')
            apply_df['Test'] = True
        test_videos = True

        kwargs = {'embeddings_df': apply_df,
                  'specified_classes': specified_classes,
                  'classes_to_remove': None,
                  'max_class_size': None,
                  'animal_state': None,
                  'view': None,
                  'shuffle_data': False,
                  'test_videos': test_videos}

        train_embeds, train_labels, train_images, test_embeds, test_labels, test_images =\
            process_dataset_in_mem(**kwargs)

        # get predictions from embeddings
        with st.spinner("Model application in progress..."):
            smoothed_predictions = get_smoothed_predictions(svm_clf, test_embeds)

        # save variables to state
        st.session_state.smoothed_predictions = smoothed_predictions
        st.session_state.test_labels = test_labels

if st.session_state.smoothed_predictions is not None:
    _render_results(st.session_state.smoothed_predictions, st.session_state.test_labels)