"""Streamlit page: train an SVM frame classifier on video-frame embeddings.

The page lets the user either generate embeddings from an uploaded video plus
annotation file(s), or upload a pre-computed embeddings CSV.  It then filters
the dataset by label, trains an RBF-kernel SVM, reports held-out accuracy and
offers the fitted model as a pickle download.
"""
import os
import io
import pickle
from pathlib import Path

import regex
import streamlit as st
import plotly.express as px
import numpy as np
import pandas as pd
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModel
from tqdm import tqdm
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report

from utils.seqIo import seqIo_reader
from utils.utils import create_embeddings_csv_io, process_dataset_in_mem, generate_embeddings_stream_io

# --server.maxUploadSize 3000
# NOTE(review): the line above looks like the Streamlit CLI flag this page is
# meant to be launched with (raises the upload limit for large videos) --
# confirm against the project's run instructions.


def get_unique_labels(label_list: list[str]) -> list[str]:
    """Return the distinct individual labels found in *label_list*.

    Each entry may hold several labels joined by ``'||'``; those are split
    apart before de-duplication.

    Args:
        label_list: Raw label strings, one per frame.

    Returns:
        The unique individual labels, sorted so that the widget options built
        from them appear in a stable order.
    """
    label_set = set()
    for label in label_list:
        # Missing labels read from a CSV arrive as NaN (a float), which has no
        # ``split``; they cannot be offered as a selectable class, so skip them.
        if not isinstance(label, str):
            continue
        label_set.update(label.split('||'))
    return sorted(label_set)


@st.cache_data
def get_train_test_split(train_embeds, numerical_labels, test_size=0.05, random_state=42):
    """Split embeddings and labels into train/test partitions (cached).

    Thin wrapper around ``sklearn.model_selection.train_test_split`` so that
    Streamlit can cache the result between reruns.

    Returns:
        ``X_train, X_test, y_train, y_test`` as returned by scikit-learn.
    """
    return train_test_split(
        train_embeds,
        numerical_labels,
        test_size=test_size,
        random_state=random_state,
    )


@st.cache_resource
def train_model(X_train, y_train, random_state=42):
    """Fit and return an RBF-kernel SVM with probability estimates enabled.

    Cached as a resource so the same fitted model is reused across reruns
    with identical inputs.
    """
    svm_clf = SVC(kernel='rbf', random_state=random_state, probability=True, verbose=True)
    svm_clf.fit(X_train, y_train)
    return svm_clf


def pickle_model(model):
    """Serialize *model* with pickle into an in-memory buffer.

    Returns:
        An ``io.BytesIO`` holding the pickled bytes, rewound to the start so
        it can be read immediately by the caller.
    """
    pickled = io.BytesIO()
    pickle.dump(model, pickled)
    # Without rewinding, a plain ``read()`` on the returned buffer yields b"".
    pickled.seek(0)
    return pickled


# ---------------------------------------------------------------------------
# Session-state defaults (survive Streamlit reruns)
# ---------------------------------------------------------------------------
if "embeddings_df_train" not in st.session_state:
    st.session_state.embeddings_df_train = None
if "svm_clf" not in st.session_state:
    st.session_state.svm_clf = None
    st.session_state.report_df = None
    st.session_state.accuracy = None

# ---------------------------------------------------------------------------
# Step 1: obtain embeddings (generate from video, or upload a CSV)
# ---------------------------------------------------------------------------
st.title('batik: frame classifier training')
st.text("Upload files to train classifier on.")
with st.form('embedding_generation_settings'):
    seq_file = st.file_uploader("Choose a video file", type=['seq', 'mp4'])
    annot_files = st.file_uploader(
        "Choose an annotation File",
        type=['annot', 'csv'],
        accept_multiple_files=True,
    )
    # A downsample rate below 1 is meaningless, so forbid it in the widget.
    downsample_rate = st.number_input('Downsample Rate', value=4, min_value=1)
    submit_embed_settings = st.form_submit_button('Create Embeddings', type='secondary')

st.markdown("**(Optional)** Upload embeddings.")
embeddings_csv = st.file_uploader("Choose a .csv File", type=['csv'])

# With accept_multiple_files=True the uploader returns a list that is EMPTY
# (not None) when nothing was uploaded, so test truthiness rather than None.
if submit_embed_settings and seq_file is not None and annot_files:
    video_embeddings, video_frames = generate_embeddings_stream_io(
        [seq_file], "SLIP", downsample_rate, False
    )
    fnames = [seq_file.name]
    embeddings_df = create_embeddings_csv_io(
        out="file",
        fnames=fnames,
        embeddings=video_embeddings,
        frames=video_frames,
        annotations=[annot_files],
        test_fnames=None,
        views=None,
        conditions=None,
        downsample_rate=downsample_rate,
    )
    st.session_state.embeddings_df_train = embeddings_df
elif embeddings_csv is not None:
    # This branch runs on every rerun while a CSV is attached; rewind so the
    # file is always parsed from its beginning.
    embeddings_csv.seek(0)
    embeddings_df = pd.read_csv(embeddings_csv)
    st.session_state.embeddings_df_train = embeddings_df
elif st.session_state.embeddings_df_train is None:
    # Only prompt when no embeddings have been loaded in this session.
    st.text('Please upload file(s).')

st.divider()

# ---------------------------------------------------------------------------
# Step 2: choose preprocessing options, train, evaluate, download
# ---------------------------------------------------------------------------
if st.session_state.embeddings_df_train is not None:
    st.subheader("specify dataset preprocessing options")
    st.text("Select frames with label(s) to include:")
    with st.form('train_settings'):
        label_list = st.session_state.embeddings_df_train['Label'].to_list()
        unique_label_list = get_unique_labels(label_list)
        specified_classes = st.multiselect("Label(s) included:", options=unique_label_list)
        st.text("Select label(s) that should be removed:")
        classes_to_remove = st.multiselect("Label(s) excluded:", options=unique_label_list)
        # Integer min/step keep the widget integral; a class size is a count.
        max_class_size = st.number_input(
            "(Optional) Specify max class size:",
            value=None,
            min_value=1,
            step=1,
        )
        shuffle_data = st.toggle("Shuffle data:")
        train_model_clicked = st.form_submit_button("Train Model")

    if train_model_clicked:
        kwargs = {
            'embeddings_df': st.session_state.embeddings_df_train,
            'specified_classes': specified_classes,
            'classes_to_remove': classes_to_remove,
            # Defensive cast: never hand a float count to the preprocessing step.
            'max_class_size': None if max_class_size is None else int(max_class_size),
            'animal_state': None,
            'view': None,
            'shuffle_data': shuffle_data,
            'test_videos': None,
        }
        train_embeds, train_labels, _, _, _, _ = process_dataset_in_mem(**kwargs)

        # Convert labels to numerical values.  'other' is pinned to index 0;
        # the remaining labels are sorted so the mapping baked into the saved
        # model is reproducible instead of depending on set iteration order.
        label_to_appear_first = 'other'
        unique_labels = set(train_labels)

        if len(unique_labels) < 2:
            # SVC cannot be fitted on fewer than two classes; report it in the
            # UI instead of letting the fit raise.
            st.error("At least two distinct labels are required to train a classifier.")
        else:
            unique_labels.discard(label_to_appear_first)
            label_to_index = {label_to_appear_first: 0}
            label_to_index.update(
                {label: idx + 1 for idx, label in enumerate(sorted(unique_labels))}
            )
            index_to_label = {idx: label for label, idx in label_to_index.items()}
            numerical_labels = np.array([label_to_index[label] for label in train_labels])
            print("Label Valence: ", label_to_index)

            # Split data into train and test sets
            X_train, X_test, y_train, y_test = get_train_test_split(
                train_embeds, numerical_labels, test_size=0.05, random_state=42
            )

            with st.spinner("Model training in progress..."):
                svm_clf = train_model(X_train, y_train)

            # Predict on the test set
            with st.spinner("In progress..."):
                y_pred = svm_clf.predict(X_test)
                accuracy = accuracy_score(y_test, y_pred)
                # Pass the full index list explicitly: the small test split
                # (or the data itself, for 'other') may lack some classes, and
                # without `labels=` a target_names length mismatch raises.
                class_indices = list(range(len(label_to_index)))
                report = classification_report(
                    y_test,
                    y_pred,
                    labels=class_indices,
                    target_names=[index_to_label[idx] for idx in class_indices],
                    output_dict=True,
                    zero_division=0,
                )
                report_df = pd.DataFrame(report).transpose()

            # save results to session state
            st.session_state.svm_clf = svm_clf
            st.session_state.report_df = report_df
            st.session_state.accuracy = accuracy

    if st.session_state.svm_clf is not None:
        pickled_model = pickle_model(st.session_state.svm_clf)
        st.text(f"Eval Accuracy: {st.session_state.accuracy}")
        st.subheader("Classification Report:")
        st.dataframe(st.session_state.report_df)
        st.download_button(
            "Download model as .pkl file",
            data=pickled_model,
            file_name=f"{'_'.join(specified_classes)}_classifier.pkl",
        )