Spaces:
Sleeping
Sleeping
import os | |
import pickle | |
from random import random | |
import streamlit as st | |
import matplotlib.pyplot as plt | |
from matplotlib.colors import ListedColormap | |
import numpy as np | |
import pandas as pd | |
import torch | |
from utils.mp4Io import mp4Io_reader | |
from utils.seqIo import seqIo_reader | |
import pandas as pd | |
from PIL import Image | |
from pathlib import Path | |
from transformers import AutoProcessor, AutoModel | |
from tempfile import NamedTemporaryFile | |
from tqdm import tqdm | |
from sklearn.metrics import accuracy_score, classification_report | |
from utils.utils import create_embeddings_csv_io, process_dataset_in_mem, multiclass_merge_and_filter_bouts, generate_embeddings_stream_io | |
# Launch Streamlit with `--server.maxUploadSize 3000` so that large video files can be uploaded.
def get_io_reader(uploaded_file):
    """Copy an uploaded video to a temporary file and open a reader on it.

    Args:
        uploaded_file: Upload object exposing ``.name`` (the original file
            name) and ``.getvalue()`` (the complete file contents as bytes).

    Returns:
        A ``seqIo_reader`` when the file name ends in ``seq``
        (case-insensitive), otherwise an ``mp4Io_reader``; either way the
        reader is opened on the temporary copy.
    """
    if uploaded_file.name.lower().endswith('seq'):
        suffix, open_reader = '.seq', seqIo_reader
    else:
        suffix, open_reader = '.mp4', mp4Io_reader

    # delete=False: the reader is handed the path and must be able to use the
    # file after this function returns, so the copy is left on disk.
    with NamedTemporaryFile(suffix=suffix, delete=False) as temp:
        temp.write(uploaded_file.getvalue())

    # Open the reader only once the writing handle is closed: this guarantees
    # every byte has been flushed, and reopening a still-open named temporary
    # file by name is not portable (it fails on Windows).
    return open_reader(temp.name)
def get_unique_labels(label_list: list[str]):
    """Return the distinct individual labels contained in *label_list*.

    Each entry may hold several labels joined by ``'||'``; entries are split
    on that separator and de-duplicated.

    Args:
        label_list: Label strings, possibly ``'||'``-joined.

    Returns:
        A list of the unique individual labels, sorted so that the order is
        the same on every run (plain set iteration order is not).
    """
    return sorted({
        individual_label
        for label in label_list
        for individual_label in label.split('||')
    })
def get_smoothed_predictions(svm_model, test_embeds, bout_threshold=30, proximity_threshold=2):
    """Predict a class per embedding and post-process the prediction sequence.

    Args:
        svm_model: Fitted classifier exposing ``predict``.
        test_embeds: Embeddings to classify, passed straight to ``predict``.
        bout_threshold: Forwarded to ``multiclass_merge_and_filter_bouts``
            (presumably the minimum bout length to keep - confirm against
            that helper).
        proximity_threshold: Forwarded to ``multiclass_merge_and_filter_bouts``
            (presumably the maximum gap between bouts to merge - confirm
            against that helper).

    Returns:
        Whatever ``multiclass_merge_and_filter_bouts`` returns for the raw
        predictions and the two thresholds.
    """
    # Only the hard predictions are needed; the former predict_proba call was
    # discarded, cost extra time and raised on models without probability
    # support.
    test_pred = svm_model.predict(test_embeds)
    return multiclass_merge_and_filter_bouts(test_pred, bout_threshold, proximity_threshold)
# Seed every session-state key this page reads. Each key has its own guard so
# that a Streamlit rerun never resets a value stored by an earlier run, and so
# that no key's default depends on whether another key already exists.
if "embeddings_df_apply" not in st.session_state:
    st.session_state.embeddings_df_apply = None
if "smoothed_predictions" not in st.session_state:
    st.session_state.smoothed_predictions = None
if "test_labels" not in st.session_state:
    st.session_state.test_labels = []
st.title('batik: frame classifier')
st.text("Upload files to apply trained classifier on.")

# Option 1: generate embeddings from a raw video plus its annotation file(s).
with st.form('embedding_generation_settings'):
    seq_file = st.file_uploader("Choose a video file", type=['seq', 'mp4'])
    annot_files = st.file_uploader("Choose an annotation File", type=['annot', 'csv'], accept_multiple_files=True)
    # A rate below 1 cannot describe a downsampling step, so reject it in the widget.
    downsample_rate = st.number_input('Downsample Rate', min_value=1, value=4)
    submit_embed_settings = st.form_submit_button('Create Embeddings', type='secondary')

# Option 2: load embeddings that were generated earlier.
st.markdown("**(Optional)** Upload embeddings if not generating above.")
embeddings_csv = st.file_uploader("Choose a .csv File", type=['csv'])

# NOTE(review): with accept_multiple_files=True the uploader is documented to
# return a list, so `annot_files is not None` is likely always true - confirm
# whether an empty annotation list should be allowed here.
if submit_embed_settings and seq_file is not None and annot_files is not None:
    video_embeddings, video_frames = generate_embeddings_stream_io([seq_file],
                                                                   "SLIP",
                                                                   downsample_rate,
                                                                   False)
    fnames = [seq_file.name]
    embeddings_df = create_embeddings_csv_io(out="file",
                                             fnames=fnames,
                                             embeddings=video_embeddings,
                                             frames=video_frames,
                                             annotations=[annot_files],
                                             test_fnames=None,
                                             views=None,
                                             conditions=None,
                                             downsample_rate=downsample_rate)
    st.session_state.embeddings_df_apply = embeddings_df
elif embeddings_csv is not None:
    embeddings_df = pd.read_csv(embeddings_csv)
    st.session_state.embeddings_df_apply = embeddings_df
elif st.session_state.embeddings_df_apply is None:
    # Prompt only while nothing is loaded; embeddings kept in session state
    # from an earlier run remain usable on later reruns.
    st.text('Please upload file(s).')
st.divider()
st.markdown("Upload classifier model.")
pickled_file = st.file_uploader("Choose a .pkl File", type=['pkl'])
if pickled_file is not None:
    # SECURITY: unpickling can execute arbitrary code - only load model files
    # from a trusted source.
    # The bytes are unpickled straight from memory; the earlier round trip
    # through a delete=False temporary file left a file behind on every rerun
    # and risked reading it back before the write had been flushed.
    svm_clf = pickle.loads(pickled_file.getvalue())
else:
    svm_clf = None
st.divider()
# The model can only be applied once both the embeddings and a classifier are available.
if svm_clf is not None and st.session_state.embeddings_df_apply is not None:
    st.subheader("specify dataset labels")
    available_labels = get_unique_labels(
        st.session_state.embeddings_df_apply['Label'].to_list()
    )

    with st.form('apply_model_settings'):
        st.text("Select label(s):")
        specified_classes = st.multiselect("Label(s) included:", options=available_labels)
        apply_model = st.form_submit_button("Apply Model")

    if apply_model:
        # Every row is treated as test data; add the 'Test' column when the
        # embeddings table does not already carry one.
        if 'Test' not in st.session_state.embeddings_df_apply:
            print(f'shape of df: {st.session_state.embeddings_df_apply.shape[0]}')
            st.session_state.embeddings_df_apply['Test'] = (
                [True] * st.session_state.embeddings_df_apply.shape[0]
            )
        test_videos = True

        (train_embeds, train_labels, train_images,
         test_embeds, test_labels, test_images) = process_dataset_in_mem(
            embeddings_df=st.session_state.embeddings_df_apply,
            specified_classes=specified_classes,
            classes_to_remove=None,
            max_class_size=None,
            animal_state=None,
            view=None,
            shuffle_data=False,
            test_videos=test_videos,
        )

        # Run the classifier on the test embeddings.
        with st.spinner("Model application in progress..."):
            smoothed_predictions = get_smoothed_predictions(svm_clf, test_embeds)

        # Persist the results so they survive Streamlit reruns.
        st.session_state.test_labels = test_labels
        st.session_state.smoothed_predictions = smoothed_predictions
# Evaluate, visualise and export the predictions stored in session state.
if st.session_state.smoothed_predictions is not None:
    # Convert labels to numerical values; 'other' is pinned to index 0.
    label_to_appear_first = 'other'
    unique_labels = set(st.session_state.test_labels)
    unique_labels.discard(label_to_appear_first)
    label_to_index = {label_to_appear_first: 0}
    # NOTE(review): the remaining indices follow set iteration order - confirm
    # this matches the integer encoding the classifier was trained with.
    label_to_index.update({label: idx + 1 for idx, label in enumerate(unique_labels)})
    index_to_label = {idx: label for label, idx in label_to_index.items()}
    # Label names ordered by their index; reused for the report, the colour
    # bar and the exported annotations so all three agree.
    labels = [index_to_label[idx] for idx in range(len(index_to_label))]
    numerical_labels_test = np.array([label_to_index[label] for label in st.session_state.test_labels])
    print("Label Valence: ", label_to_index)

    # Accuracy of the smoothed predictions against the test labels.
    if len(st.session_state.smoothed_predictions) > 0:
        test_accuracy = accuracy_score(numerical_labels_test, st.session_state.smoothed_predictions)
    else:
        test_accuracy = 0  # If no predictions meet the threshold, set accuracy to 0

    # Pass `labels` explicitly: without it the report derives the classes from
    # the data and raises when a class listed in `target_names` never occurs.
    report = classification_report(numerical_labels_test,
                                   st.session_state.smoothed_predictions,
                                   labels=list(range(len(labels))),
                                   target_names=labels,
                                   output_dict=True)
    report_df = pd.DataFrame(report).transpose()
    st.text(f"Eval Accuracy: {test_accuracy}")
    st.subheader("Classification Report:")
    st.dataframe(report_df)

    # create figure (behavior raster)
    # A fixed seed keeps each class colour stable across Streamlit reruns
    # instead of re-randomising it whenever any widget changes.
    color_rng = np.random.default_rng(0)
    class_colors = [tuple(color_rng.random(3)) for _ in labels[1:]]
    fig, ax = plt.subplots()
    # vmin/vmax pin the colour scale to the full class range, so colours and
    # colour-bar ticks stay aligned even when some class is never predicted.
    raster = ax.imshow(st.session_state.smoothed_predictions.reshape((1, st.session_state.smoothed_predictions.size)),
                       aspect='auto',
                       interpolation='nearest',
                       cmap=ListedColormap(['white'] + class_colors),
                       vmin=0,
                       vmax=len(labels) - 1)
    ax.set_yticklabels([])
    ax.set_xlabel('frames')
    cbar = fig.colorbar(raster)
    # Place one tick at the centre of each class's colour band.
    spacing = (len(labels) - 1) / len(labels)
    start = spacing / 2
    ticks = [start] + [start + spacing * i for i in range(1, len(labels))]
    cbar.set_ticks(ticks=ticks, labels=labels)
    st.pyplot(fig)
    # Close the figure so reruns do not accumulate open matplotlib figures.
    plt.close(fig)

    # save generated annotations
    annotations = [labels[x] for x in st.session_state.smoothed_predictions]
    annotations_df = pd.DataFrame(annotations, columns=['label'])
    csv = annotations_df.to_csv(header=False).encode("utf-8")
    output_file_name = st.text_input("Output File Name:", "output")
    st.download_button("Download annotations as .csv",
                       data=csv,
                       file_name=f"{output_file_name}.csv")