"""Streamlit page that generates video embeddings and exports them as a CSV.

Workflow:
1. Upload one or more video files (``.seq`` / ``.mp4``) and create embeddings
   with the selected model and downsample rate.
2. Upload the annotation files (``.annot`` / ``.csv``), assign an annotation
   entry, a test flag, a view and a condition to every video, and download the
   resulting embeddings CSV.
"""
import streamlit as st
import pandas as pd
from utils.utils import create_embeddings_csv_io, create_annot_fname_dict_io, generate_embeddings_stream_io

# Streamlit re-executes this script on every interaction; results that must
# survive those reruns live in session_state.
if "video_embeddings" not in st.session_state:
    st.session_state.video_embeddings = None
    st.session_state.video_frames = None
    st.session_state.fnames = []

st.title('batik: embedding generator')

uploaded_files = st.file_uploader("Choose a video file", type=['seq', 'mp4'], accept_multiple_files=True)

with st.form('initial_settings'):
    st.header('Embedding Generation Options')
    model_select = st.selectbox('Select Model', ['SLIP', 'CLIP'])
    # A rate below 1 is not a meaningful downsampling factor.
    downsample_rate = st.number_input('Downsample Rate', value=4, min_value=1)
    # NOTE(review): the final `False` passed to generate_embeddings_stream_io
    # presumably stands in for a disabled "Save Individual Results" toggle
    # (st.toggle('Save Individual Results', value=False)); confirm.
    submit_initial_settings = st.form_submit_button('Create Embeddings', type='secondary')

# `uploaded_files` is an empty list (not None) when nothing is uploaded with
# accept_multiple_files=True; truthiness covers both cases.
if submit_initial_settings and uploaded_files:
    video_embeddings, video_frames = generate_embeddings_stream_io(uploaded_files, model_select, downsample_rate, False)
    st.session_state.video_embeddings = video_embeddings
    st.session_state.video_frames = video_frames
    st.session_state.fnames = [vid_file.name for vid_file in uploaded_files]
    # Remember the rate these embeddings were produced with, so the CSV stays
    # consistent even if the form value changes later without regenerating.
    st.session_state.embeddings_downsample_rate = downsample_rate

if st.session_state.video_embeddings is not None:
    st.header('CSV Configuration Options')
    st.markdown('If using `.annot` files and multiple files should be grouped together, '
                'please ensure that they share a common name and end with a number describing '
                'the order of the files. For example:\n\n'
                '`mouse_224_file_1.annot`, `mouse_224_file_2.annot`.')
    annot_files = st.file_uploader("Upload all annotation files", type=['.annot', '.csv'], accept_multiple_files=True)

    # Maps each label shown in the "Annotations" selectbox to the matching value
    # from create_annot_fname_dict_io. Keys are str(key) so the lookup matches
    # exactly what the selectbox hands back.
    annot_lookup = {}
    if annot_files:
        annot_fnames = [annot_file.name for annot_file in annot_files]
        annot_fname_dict = create_annot_fname_dict_io(annot_fnames=annot_fnames, annot_files=annot_files)
        annot_lookup = {str(key): value for key, value in annot_fname_dict.items()}
    annot_options = list(annot_lookup)

    if annot_options:
        video_fnames = st.session_state.fnames
        with st.form('csv_settings'):
            csv_setting_def = pd.DataFrame(
                {
                    "File Name": video_fnames,
                    # Placeholder that is not one of the options: the user must
                    # pick a real annotation entry for every row.
                    "Annotations": ["Upload File"] * len(video_fnames),
                    "Test": [False] * len(video_fnames),
                    "View": ["Top"] * len(video_fnames),
                    "Condition": ["None"] * len(video_fnames),
                }
            )
            csv_settings = st.data_editor(
                csv_setting_def,
                column_config={
                    "Annotations": st.column_config.SelectboxColumn(
                        "Annotations",
                        help="The annotation file(s) to use for the given video file.",
                        width="medium",
                        options=annot_options,
                        required=True,
                    ),
                    "Test": st.column_config.CheckboxColumn(
                        "Test",
                        help="Designate file(s) to use as the test set.",
                        default=False,
                        required=True,
                    ),
                    "View": st.column_config.SelectboxColumn(
                        "View",
                        help="The view used within the video (either Top or Front).",
                        options=["Top", "Front"],
                        required=True,
                    ),
                    "Condition": st.column_config.TextColumn(
                        "Condition",
                        help="A condition the video has (e.g. Control).",
                        default="None",
                        max_chars=30,
                        # NOTE(review): this pattern is not anchored with '^', so
                        # presumably any value that merely ends in lowercase letters
                        # (e.g. the default "None") passes; confirm the intent.
                        validate=r"[a-z]+$",
                    ),
                },
                hide_index=True,
                # File names come from the uploaded videos and are read from
                # session_state, so edits to this column would be ignored.
                disabled=["File Name"],
            )
            # Inside the form so the name is set before submitting; outside the
            # form, editing it reruns the script with the button back to False.
            out_name = st.text_input("Embeddings Output File Name", "out.csv")
            save_csv_bttn = st.form_submit_button("Create CSV")

        if save_csv_bttn and csv_settings is not None:
            annot_chosen_options = csv_settings['Annotations'].tolist()
            unassigned = [
                fname
                for fname, choice in zip(video_fnames, annot_chosen_options)
                if choice not in annot_lookup
            ]
            if unassigned:
                # Previously this raised an unhandled KeyError.
                st.error('Please choose an annotation file for: ' + ', '.join(unassigned))
            else:
                annot_option = [annot_lookup[key] for key in annot_chosen_options]
                test_option = [
                    fname
                    for fname, is_test in zip(video_fnames, csv_settings['Test'].tolist())
                    if is_test
                ]
                view_option = csv_settings['View'].tolist()
                condition_option = csv_settings['Condition'].tolist()
                out_name = out_name.strip() or "out.csv"
                try:
                    df = create_embeddings_csv_io(out=out_name,
                                                  fnames=video_fnames,
                                                  embeddings=st.session_state.video_embeddings,
                                                  frames=st.session_state.video_frames,
                                                  annotations=annot_option,
                                                  test_fnames=test_option,
                                                  views=view_option,
                                                  conditions=condition_option,
                                                  downsample_rate=st.session_state.get(
                                                      'embeddings_downsample_rate', downsample_rate))
                except Exception as err:
                    # Surface the actual cause instead of silently swallowing it.
                    st.error('Something went wrong.')
                    st.exception(err)
                else:
                    st.success('Created Embeddings File!', icon="✅")
                    st.download_button(
                        label="Download CSV",
                        data=df.to_csv().encode("utf-8"),
                        file_name=out_name,
                        mime="text/csv",
                    )
    else:
        st.text('Please Upload Files')
else:
    st.text('Please Upload Files')