Spaces:
Sleeping
Sleeping
import os | |
import io | |
import pickle | |
import regex | |
import streamlit as st | |
import plotly.express as px | |
import numpy as np | |
import pandas as pd | |
import torch | |
from utils.seqIo import seqIo_reader | |
import pandas as pd | |
from PIL import Image | |
from pathlib import Path | |
from transformers import AutoProcessor, AutoModel | |
from tqdm import tqdm | |
from sklearn.svm import SVC | |
from sklearn.model_selection import train_test_split | |
from sklearn.metrics import accuracy_score, classification_report | |
from utils.utils import create_embeddings_csv_io, process_dataset_in_mem, generate_embeddings_stream_io | |
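# Run with `streamlit run <this file> --server.maxUploadSize 3000` to raise the upload size limit (in MB) so large video files can be uploaded.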
def get_unique_labels(label_list: list[str]) -> list[str]:
    """Return the distinct individual labels found in ``label_list``.

    Each entry may hold several labels joined by ``'||'``; they are split
    apart and de-duplicated. Non-string entries (e.g. the ``NaN`` pandas
    produces for an empty ``Label`` cell) are skipped instead of raising.

    Returns:
        The unique labels in sorted order, so option lists built from them
        are the same on every run.
    """
    unique = set()
    for entry in label_list:
        if not isinstance(entry, str):
            # Missing/non-text label cell: nothing to split.
            continue
        unique.update(entry.split('||'))
    return sorted(unique)
def get_train_test_split(train_embeds, numerical_labels, test_size=0.05, random_state=42):
    """Split embeddings and their numeric labels into train/test partitions.

    Thin wrapper around ``sklearn.model_selection.train_test_split`` with a
    fixed default test fraction and seed.

    Returns:
        ``(X_train, X_test, y_train, y_test)``, in sklearn's order.
    """
    split_options = {"test_size": test_size, "random_state": random_state}
    return train_test_split(train_embeds, numerical_labels, **split_options)
def train_model(X_train, y_train, random_state=42):
    """Fit an RBF-kernel SVM classifier on the training split and return it.

    The classifier is built with ``probability=True`` so the fitted model can
    also produce class probabilities, and with ``verbose=True`` so the
    underlying solver reports progress while fitting.
    """
    classifier = SVC(
        kernel='rbf',
        probability=True,
        random_state=random_state,
        verbose=True,
    )
    classifier.fit(X_train, y_train)
    return classifier
def pickle_model(model):
    """Serialize ``model`` with :mod:`pickle` into an in-memory buffer.

    The buffer is rewound to the start before it is returned, so a consumer
    that calls ``.read()`` receives the full payload rather than ``b''``.

    Note: unpickling executes code embedded in the payload; only load a file
    produced here in a trusted context.

    Returns:
        An ``io.BytesIO`` positioned at offset 0 containing the pickled model.
    """
    buffer = io.BytesIO()
    pickle.dump(model, buffer)
    buffer.seek(0)
    return buffer
# Streamlit re-runs this whole script on every interaction. Seed each
# session-state key only when it is missing, so stored embeddings and the
# trained model/report survive reruns instead of being reset to None.
_SESSION_DEFAULTS = {
    "embeddings_df_train": None,
    "svm_clf": None,
    "report_df": None,
    "accuracy": None,
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
st.title('batik: frame classifier training')
st.text("Upload files to train classifier on.")

# Option 1: generate embeddings from a video plus its annotation file(s).
with st.form('embedding_generation_settings'):
    seq_file = st.file_uploader("Choose a video file", type=['seq', 'mp4'])
    annot_files = st.file_uploader("Choose an annotation File", type=['annot', 'csv'], accept_multiple_files=True)
    downsample_rate = st.number_input('Downsample Rate', value=4)
    submit_embed_settings = st.form_submit_button('Create Embeddings', type='secondary')

# Option 2: load embeddings that were previously exported as CSV.
st.markdown("**(Optional)** Upload embeddings.")
embeddings_csv = st.file_uploader("Choose a .csv File", type=['csv'])

# With accept_multiple_files=True the uploader returns a list (empty when
# nothing is uploaded), never None -- test truthiness so at least one
# annotation file is required before embedding generation starts.
if submit_embed_settings and seq_file is not None and annot_files:
    video_embeddings, video_frames = generate_embeddings_stream_io([seq_file],
                                                                   "SLIP",
                                                                   downsample_rate,
                                                                   False)
    fnames = [seq_file.name]
    embeddings_df = create_embeddings_csv_io(out="file",
                                             fnames=fnames,
                                             embeddings=video_embeddings,
                                             frames=video_frames,
                                             # presumably parallel to fnames: one list of annotation files per video
                                             annotations=[annot_files],
                                             test_fnames=None,
                                             views=None,
                                             conditions=None,
                                             downsample_rate=downsample_rate)
    st.session_state.embeddings_df_train = embeddings_df
elif embeddings_csv is not None:
    embeddings_df = pd.read_csv(embeddings_csv)
    st.session_state.embeddings_df_train = embeddings_df
else:
    st.text('Please upload file(s).')

st.divider()
def _encode_labels(labels, first_label='other'):
    """Map labels to contiguous integer class ids.

    ``first_label`` always gets id 0 (even when absent from ``labels``), so it
    keeps the same id in every trained model; the remaining labels are
    numbered from 1 in sorted order, making the mapping reproducible.

    Returns:
        ``(label_to_index, index_to_label, numerical_labels)`` where
        ``numerical_labels`` is a NumPy array aligned with ``labels``.
    """
    remaining = sorted(set(labels) - {first_label}, key=str)
    label_to_index = {first_label: 0}
    label_to_index.update({label: idx for idx, label in enumerate(remaining, start=1)})
    index_to_label = {idx: label for label, idx in label_to_index.items()}
    numerical_labels = np.array([label_to_index[label] for label in labels])
    return label_to_index, index_to_label, numerical_labels


if st.session_state.embeddings_df_train is not None:
    st.subheader("specify dataset preprocessing options")
    st.text("Select frames with label(s) to include:")
    with st.form('train_settings'):
        label_list = st.session_state.embeddings_df_train['Label'].to_list()
        unique_label_list = get_unique_labels(label_list)
        specified_classes = st.multiselect("Label(s) included:", options=unique_label_list)
        st.text("Select label(s) that should be removed:")
        classes_to_remove = st.multiselect("Label(s) excluded:", options=unique_label_list)
        max_class_size = st.number_input("(Optional) Specify max class size:", value=None)
        shuffle_data = st.toggle("Shuffle data:")
        train_model_clicked = st.form_submit_button("Train Model")

    if train_model_clicked:
        kwargs = {'embeddings_df': st.session_state.embeddings_df_train,
                  'specified_classes': specified_classes,
                  'classes_to_remove': classes_to_remove,
                  # number_input(value=None) yields a float once filled in;
                  # a class size is a count, so pass an int.
                  'max_class_size': None if max_class_size is None else int(max_class_size),
                  'animal_state': None,
                  'view': None,
                  'shuffle_data': shuffle_data,
                  'test_videos': None}
        train_embeds, train_labels, _, _, _, _ = process_dataset_in_mem(**kwargs)
        train_labels = list(train_labels)

        # Convert labels to numerical values ('other' is pinned to id 0).
        label_to_index, index_to_label, numerical_labels = _encode_labels(train_labels, 'other')
        print("Label Valence: ", label_to_index)

        if len(set(train_labels)) < 2:
            # SVC needs at least two classes; an empty selection would also
            # make the train/test split fail.
            st.error("Need frames from at least two distinct labels to train a classifier.")
        else:
            # Split data into train and test sets
            X_train, X_test, y_train, y_test = get_train_test_split(train_embeds, numerical_labels, test_size=0.05, random_state=42)
            with st.spinner("Model training in progress..."):
                svm_clf = train_model(X_train, y_train)
            # Predict on the test set
            with st.spinner("In progress..."):
                y_pred = svm_clf.predict(X_test)
                accuracy = accuracy_score(y_test, y_pred)
                # Pass every class id explicitly: the small test split (or a
                # missing 'other' class) may not contain every label, and
                # target_names must match the label list in length.
                all_ids = list(range(len(label_to_index)))
                report = classification_report(y_test, y_pred,
                                               labels=all_ids,
                                               target_names=[index_to_label[idx] for idx in all_ids],
                                               output_dict=True,
                                               zero_division=0)
                report_df = pd.DataFrame(report).transpose()
            # save results to session state
            st.session_state.svm_clf = svm_clf
            st.session_state.report_df = report_df
            st.session_state.accuracy = accuracy
            # Remember which classes this model was trained on for the download name.
            st.session_state.trained_classes = list(specified_classes)

if st.session_state.svm_clf is not None:
    pickled_model = pickle_model(st.session_state.svm_clf)
    st.text(f"Eval Accuracy: {st.session_state.accuracy}")
    st.subheader("Classification Report:")
    st.dataframe(st.session_state.report_df)
    # Name the file after the classes recorded at training time rather than a
    # form variable that only exists when the training form was rendered.
    model_stem = '_'.join(st.session_state.get("trained_classes", [])) or 'svm'
    st.download_button("Download model as .pkl file",
                       data=pickled_model,
                       file_name=f"{model_stem}_classifier.pkl")