Spaces:
Sleeping
Sleeping
replace state key for embeddings_df
Browse filesneeded to replace key so that state wouldn't get messed up between pages
- apply_model.py +11 -11
apply_model.py
CHANGED
@@ -49,8 +49,8 @@ def get_smoothed_predictions(svm_model, test_embeds):
|
|
49 |
predictions = multiclass_merge_and_filter_bouts(test_pred, bout_threshold, proximity_threshold)
|
50 |
return predictions
|
51 |
|
52 |
-
if "
|
53 |
-
st.session_state.
|
54 |
|
55 |
if "smoothed_predictions" not in st.session_state:
|
56 |
st.session_state.smoothed_predictions = None
|
@@ -84,11 +84,11 @@ if submit_embed_settings and seq_file is not None and annot_files is not None:
|
|
84 |
views=None,
|
85 |
conditions=None,
|
86 |
downsample_rate=downsample_rate)
|
87 |
-
st.session_state.
|
88 |
|
89 |
elif embeddings_csv is not None:
|
90 |
embeddings_df = pd.read_csv(embeddings_csv)
|
91 |
-
st.session_state.
|
92 |
else:
|
93 |
st.text('Please upload file(s).')
|
94 |
|
@@ -105,9 +105,9 @@ else:
|
|
105 |
svm_clf = None
|
106 |
|
107 |
st.divider()
|
108 |
-
if st.session_state.
|
109 |
st.subheader("specify dataset labels")
|
110 |
-
label_list = st.session_state.
|
111 |
unique_label_list = get_unique_labels(label_list)
|
112 |
|
113 |
with st.form('apply_model_settings'):
|
@@ -118,15 +118,15 @@ if st.session_state.embeddings_df is not None and svm_clf is not None:
|
|
118 |
apply_model = st.form_submit_button("Apply Model")
|
119 |
|
120 |
if apply_model:
|
121 |
-
if 'Test' in st.session_state.
|
122 |
test_videos = True
|
123 |
else:
|
124 |
-
print(f'shape of df: {st.session_state.
|
125 |
-
test_videos_array = [True for i in range(st.session_state.
|
126 |
-
st.session_state.
|
127 |
test_videos = True
|
128 |
|
129 |
-
kwargs = {'embeddings_df' : st.session_state.
|
130 |
'specified_classes' : specified_classes,
|
131 |
'classes_to_remove' : None,
|
132 |
'max_class_size' : None,
|
|
|
49 |
predictions = multiclass_merge_and_filter_bouts(test_pred, bout_threshold, proximity_threshold)
|
50 |
return predictions
|
51 |
|
52 |
+
if "embeddings_df_apply" not in st.session_state:
|
53 |
+
st.session_state.embeddings_df_apply = None
|
54 |
|
55 |
if "smoothed_predictions" not in st.session_state:
|
56 |
st.session_state.smoothed_predictions = None
|
|
|
84 |
views=None,
|
85 |
conditions=None,
|
86 |
downsample_rate=downsample_rate)
|
87 |
+
st.session_state.embeddings_df_apply = embeddings_df
|
88 |
|
89 |
elif embeddings_csv is not None:
|
90 |
embeddings_df = pd.read_csv(embeddings_csv)
|
91 |
+
st.session_state.embeddings_df_apply = embeddings_df
|
92 |
else:
|
93 |
st.text('Please upload file(s).')
|
94 |
|
|
|
105 |
svm_clf = None
|
106 |
|
107 |
st.divider()
|
108 |
+
if st.session_state.embeddings_df_apply is not None and svm_clf is not None:
|
109 |
st.subheader("specify dataset labels")
|
110 |
+
label_list = st.session_state.embeddings_df_apply['Label'].to_list()
|
111 |
unique_label_list = get_unique_labels(label_list)
|
112 |
|
113 |
with st.form('apply_model_settings'):
|
|
|
118 |
apply_model = st.form_submit_button("Apply Model")
|
119 |
|
120 |
if apply_model:
|
121 |
+
if 'Test' in st.session_state.embeddings_df_apply:
|
122 |
test_videos = True
|
123 |
else:
|
124 |
+
print(f'shape of df: {st.session_state.embeddings_df_apply.shape[0]}')
|
125 |
+
test_videos_array = [True for i in range(st.session_state.embeddings_df_apply.shape[0])]
|
126 |
+
st.session_state.embeddings_df_apply['Test'] = test_videos_array
|
127 |
test_videos = True
|
128 |
|
129 |
+
kwargs = {'embeddings_df' : st.session_state.embeddings_df_apply,
|
130 |
'specified_classes' : specified_classes,
|
131 |
'classes_to_remove' : None,
|
132 |
'max_class_size' : None,
|