import streamlit as st
import pandas as pd
from utils.utils import create_embeddings_csv_io, create_annot_fname_dict_io, generate_embeddings_stream_io

# Session state keeps generated embeddings across Streamlit reruns, so the
# embedding step only runs when the user explicitly submits the form below.
if "video_embeddings" not in st.session_state:
    st.session_state.video_embeddings = None
    st.session_state.video_frames = None
    st.session_state.fnames = []

st.title('batik: embedding generator')
uploaded_files = st.file_uploader("Choose a video file", type=['seq', 'mp4'], accept_multiple_files=True)
with st.form('initial_settings'):
    st.header('Embedding Generation Options')
    model_select = st.selectbox('Select Model', ['SLIP', 'CLIP'])
    # Restrict to positive integers: a downsample rate of zero or below is not meaningful.
    downsample_rate = st.number_input('Downsample Rate', value=4, min_value=1, step=1)
    submit_initial_settings = st.form_submit_button('Create Embeddings', type='secondary')

if submit_initial_settings and uploaded_files:
    # The trailing False is passed for the (currently disabled) "save individual
    # results" option. NOTE(review): inferred from the removed toggle — confirm
    # against generate_embeddings_stream_io's signature.
    video_embeddings, video_frames = generate_embeddings_stream_io(uploaded_files,
                                                                model_select,
                                                                downsample_rate,
                                                                False)
    st.session_state.video_embeddings = video_embeddings
    st.session_state.video_frames = video_frames
    st.session_state.fnames = [vid_file.name for vid_file in uploaded_files]
    # Remember the rate these embeddings were built with, so CSV creation uses
    # it even if the form is later resubmitted without regenerating.
    st.session_state.embeddings_downsample_rate = downsample_rate
elif submit_initial_settings:
    st.warning('Please upload at least one video file before creating embeddings.')

# Once embeddings exist, let the user attach annotation files and per-video
# metadata, then build a downloadable embeddings CSV.
if st.session_state.video_embeddings is not None:
    st.header('CSV Configuration Options')
    st.markdown('If using `.annot` files and multiple files should be grouped together, '\
                'please ensure that they share a common name and end with a number describing '\
                'the order of the files. For example:\n\n'\
                '`mouse_224_file_1.annot`, `mouse_224_file_2.annot`.')
    annot_files = st.file_uploader("Upload all annotation files", type=['.annot','.csv'], accept_multiple_files=True)

    # Map each selectbox label to its annotation entry. The labels are the
    # stringified dict keys, so the choices shown and the lookup always agree.
    annot_lookup = {}
    if annot_files:
        annot_fnames = [annot_file.name for annot_file in annot_files]
        annot_fname_dict = create_annot_fname_dict_io(annot_fnames=annot_fnames,
                                                      annot_files=annot_files)
        annot_lookup = {str(key): value for key, value in annot_fname_dict.items()}
    annot_options = list(annot_lookup)

    if annot_options:
        fnames = st.session_state.fnames
        with st.form('csv_settings'):
            csv_setting_def = pd.DataFrame(
                {
                    "File Name" : fnames,
                    # Placeholder; not a valid option, so it is validated on submit.
                    "Annotations" : ["Upload File" for _ in fnames],
                    "Test" : [False for _ in fnames],
                    "View" : ["Top" for _ in fnames],
                    "Condition" : ["None" for _ in fnames],
                }
            )

            csv_settings = st.data_editor(
                csv_setting_def,
                column_config={
                    "Annotations" : st.column_config.SelectboxColumn(
                        "Annotations",
                        help="The annotation file(s) to use for the given video file.",
                        width="medium",
                        options=annot_options,
                        required=True
                    ),
                    "Test" : st.column_config.CheckboxColumn(
                        "Test",
                        help="Designate file(s) to use as the test set.",
                        default=False,
                        required=True
                    ),
                    "View" : st.column_config.SelectboxColumn(
                        "View",
                        help="The view used within the video (either Top or Front).",
                        options=["Top", "Front"],
                        required=True
                    ),
                    "Condition" : st.column_config.TextColumn(
                        "Condition",
                        help="A condition the video has (i.e. Control).",
                        default="None",
                        max_chars=30,
                        validate=r"[a-z]+$",
                    )
                },
                hide_index=True,
                # File names come from session state; edits here would be ignored.
                disabled=["File Name"],
            )
            # Inside the form so its value is submitted together with the table;
            # outside the form, editing it reruns the app with the button unset.
            out_name = st.text_input("Embeddings Output File Name", "out.csv")
            save_csv_bttn = st.form_submit_button("Create CSV")

        if save_csv_bttn and csv_settings is not None:
            annot_chosen_options = csv_settings['Annotations'].tolist()
            unassigned = [fname for fname, choice in zip(fnames, annot_chosen_options)
                          if choice not in annot_lookup]
            if unassigned:
                st.error('Please select an annotation file for: ' + ', '.join(unassigned))
            else:
                annot_option = [annot_lookup[choice] for choice in annot_chosen_options]
                test_option = [fname for fname, is_test in zip(fnames, csv_settings['Test'].tolist())
                               if is_test]
                view_option = csv_settings['View'].tolist()
                # A cleared cell comes back empty; treat it as the column default.
                condition_option = csv_settings['Condition'].fillna('None').tolist()
                out_name = out_name.strip() or 'out.csv'
                # Prefer the rate the embeddings were actually generated with.
                used_downsample_rate = st.session_state.get('embeddings_downsample_rate',
                                                            downsample_rate)
                try:
                    df = create_embeddings_csv_io(out=out_name,
                                                  fnames=fnames,
                                                  embeddings=st.session_state.video_embeddings,
                                                  frames=st.session_state.video_frames,
                                                  annotations=annot_option,
                                                  test_fnames=test_option,
                                                  views=view_option,
                                                  conditions=condition_option,
                                                  downsample_rate=used_downsample_rate)
                except Exception as err:
                    st.error(f'Something went wrong: {err}')
                else:
                    st.success('Created Embeddings File!', icon="✅")
                    st.download_button(
                        label="Download CSV",
                        data=df.to_csv().encode("utf-8"),
                        file_name=out_name,
                        mime="text/csv"
                    )

    else:
        st.text('Please Upload Annotation Files')
else:
    st.text('Please Upload Files')