File size: 6,245 Bytes
8cfd894
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import streamlit as st
import pandas as pd
from utils.utils import create_embeddings_csv_io, create_annot_fname_dict_io, generate_embeddings_stream_io

# Streamlit re-executes this script on every interaction, so seed the
# per-session storage only on the very first run of a session.
if "video_embeddings" not in st.session_state:
    for state_key, initial_value in (
        ("video_embeddings", None),  # embeddings from the last generation run
        ("video_frames", None),      # frame data returned alongside them
        ("fnames", []),              # names of the videos that were embedded
    ):
        st.session_state[state_key] = initial_value

# Page header, video upload, and the embedding-generation settings form.
st.title('batik: embedding generator')
uploaded_files = st.file_uploader("Choose a video file", type=['seq', 'mp4'], accept_multiple_files=True)
with st.form('initial_settings'):
    st.header('Embedding Generation Options')
    model_select = st.selectbox('Select Model', ['SLIP', 'CLIP'])
    # Downsampling factor handed to both the embedding generator and the CSV
    # builder; a factor below 1 is meaningless, so only positive integers are accepted.
    downsample_rate = st.number_input('Downsample Rate', value=4, min_value=1, step=1)
    # NOTE: a "Save Individual Results" toggle used to live here; it is
    # disabled and the embedding generation call presumably receives a
    # hard-coded False in its place.
    submit_initial_settings = st.form_submit_button('Create Embeddings', type='secondary')

# Generate embeddings when the settings form is submitted, and keep the
# results in session state so they survive Streamlit's reruns.
if submit_initial_settings:
    if uploaded_files:
        # Generation over whole videos may be slow; give the user feedback.
        with st.spinner('Generating embeddings...'):
            video_embeddings, video_frames = generate_embeddings_stream_io(uploaded_files,
                                                                        model_select,
                                                                        downsample_rate,
                                                                        False)
        fnames = [vid_file.name for vid_file in uploaded_files]
        st.session_state.video_embeddings = video_embeddings
        st.session_state.video_frames = video_frames
        st.session_state.fnames = fnames
    else:
        # Previously a submit without videos did nothing at all.
        st.warning('Please upload at least one video file before creating embeddings.')

if st.session_state.video_embeddings is not None:
    # Embeddings exist for this session: collect per-video metadata
    # (annotations, test split, view, condition) and build the embeddings CSV.
    st.header('CSV Configuration Options')
    st.markdown('If using `.annot` files and multiple files should be grouped together, '
                'please ensure that they share a common name and end with a number describing '
                'the order of the files. For example:\n\n'
                '`mouse_224_file_1.annot`, `mouse_224_file_2.annot`.')
    annot_files = st.file_uploader("Upload all annotation files", type=['.annot','.csv'], accept_multiple_files=True)

    # Option label -> annotation entry built by create_annot_fname_dict_io.
    # Labels are stringified exactly like the selectbox options, so the value
    # the data editor hands back can always be looked up directly.
    annot_lookup = {}
    if annot_files:
        annot_fnames = [annot_file.name for annot_file in annot_files]
        annot_fname_dict = create_annot_fname_dict_io(annot_fnames=annot_fnames,
                                                      annot_files=annot_files)
        annot_lookup = {str(key): value for key, value in annot_fname_dict.items()}
    annot_options = list(annot_lookup)

    if annot_options:
        video_fnames = st.session_state.fnames
        n_videos = len(video_fnames)
        with st.form('csv_settings'):
            csv_setting_def = pd.DataFrame(
                {
                    "File Name": video_fnames,
                    # "Upload File" is a placeholder outside annot_options;
                    # rows still holding it are rejected on submit below.
                    "Annotations": ["Upload File"] * n_videos,
                    "Test": [False] * n_videos,
                    "View": ["Top"] * n_videos,
                    "Condition": ["None"] * n_videos,
                }
            )

            csv_settings = st.data_editor(
                csv_setting_def,
                column_config={
                    "Annotations" : st.column_config.SelectboxColumn(
                        "Annotations",
                        help="The annotation file(s) to use for the given video file.",
                        width="medium",
                        options=annot_options,
                        required=True
                    ),
                    "Test" : st.column_config.CheckboxColumn(
                        "Test",
                        help="Designate file(s) to use as the test set.",
                        default=False,
                        required=True
                    ),
                    "View" : st.column_config.SelectboxColumn(
                        "View",
                        help="The view used within the video (either Top or Front).",
                        options=["Top", "Front"],
                        required=True
                    ),
                    "Condition" : st.column_config.TextColumn(
                        "Condition",
                        help="A condition the video has (i.e. Control).",
                        default="None",
                        max_chars=30,
                        # NOTE(review): pattern is not anchored with '^', so it
                        # only requires the value to end in lowercase letters;
                        # confirm whether e.g. '^[A-Za-z]+$' was intended.
                        validate=r"[a-z]+$",
                    )
                },
                # Video names are read back from session state, so edits to
                # this column would be silently ignored; make it read-only.
                disabled=["File Name"],
                hide_index=True
            )
            # Lives inside the form so the name is submitted with the table.
            # Created after the submit (as before), any edit would rerun the
            # script with the button False and the input would disappear.
            out_name = st.text_input("Embeddings Output File Name", "out.csv")
            save_csv_bttn = st.form_submit_button("Create CSV")

        if save_csv_bttn and csv_settings is not None:
            annot_chosen_options = csv_settings['Annotations'].tolist()
            # Rows whose selection is not a real option (e.g. the placeholder).
            missing_annots = [
                fname for fname, choice in zip(video_fnames, annot_chosen_options)
                if choice not in annot_lookup
            ]
            if missing_annots:
                st.error('Please select an annotation file for: ' + ', '.join(missing_annots))
            else:
                annot_option = [annot_lookup[key] for key in annot_chosen_options]
                test_option = [
                    fname for fname, is_test in zip(video_fnames, csv_settings['Test'].tolist())
                    if is_test
                ]
                view_option = csv_settings['View'].tolist()
                condition_option = csv_settings['Condition'].tolist()
                # Never hand an empty file name to the CSV builder.
                out_name = out_name.strip() or "out.csv"
                try:
                    df = create_embeddings_csv_io(out=out_name,
                                                  fnames=video_fnames,
                                                  embeddings=st.session_state.video_embeddings,
                                                  frames=st.session_state.video_frames,
                                                  annotations=annot_option,
                                                  test_fnames=test_option,
                                                  views=view_option,
                                                  conditions=condition_option,
                                                  downsample_rate=downsample_rate)
                except Exception as err:
                    # Top-level UI boundary: report the failure with its details
                    # instead of hiding the cause behind a generic message.
                    st.error('Something went wrong.')
                    st.exception(err)
                else:
                    st.success('Created Embeddings File!', icon="✅")
                    st.download_button(
                        label="Download CSV",
                        data=df.to_csv().encode("utf-8"),
                        file_name=out_name,
                        mime="text/csv"
                    )

    else:
        # Videos are already embedded; only the annotation files are missing.
        st.text('Please Upload Annotation Files')
else:
    st.text('Please Upload Files')