ncoria commited on
Commit
389f3f5
·
verified ·
1 Parent(s): 8cfd894

replace state key for embeddings_df

Browse files

needed to replace key so that state wouldn't get messed up between pages

Files changed (1) hide show
  1. apply_model.py +11 -11
apply_model.py CHANGED
@@ -49,8 +49,8 @@ def get_smoothed_predictions(svm_model, test_embeds):
49
  predictions = multiclass_merge_and_filter_bouts(test_pred, bout_threshold, proximity_threshold)
50
  return predictions
51
 
52
- if "embeddings_df" not in st.session_state:
53
- st.session_state.embeddings_df = None
54
 
55
  if "smoothed_predictions" not in st.session_state:
56
  st.session_state.smoothed_predictions = None
@@ -84,11 +84,11 @@ if submit_embed_settings and seq_file is not None and annot_files is not None:
84
  views=None,
85
  conditions=None,
86
  downsample_rate=downsample_rate)
87
- st.session_state.embeddings_df = embeddings_df
88
 
89
  elif embeddings_csv is not None:
90
  embeddings_df = pd.read_csv(embeddings_csv)
91
- st.session_state.embeddings_df = embeddings_df
92
  else:
93
  st.text('Please upload file(s).')
94
 
@@ -105,9 +105,9 @@ else:
105
  svm_clf = None
106
 
107
  st.divider()
108
- if st.session_state.embeddings_df is not None and svm_clf is not None:
109
  st.subheader("specify dataset labels")
110
- label_list = st.session_state.embeddings_df['Label'].to_list()
111
  unique_label_list = get_unique_labels(label_list)
112
 
113
  with st.form('apply_model_settings'):
@@ -118,15 +118,15 @@ if st.session_state.embeddings_df is not None and svm_clf is not None:
118
  apply_model = st.form_submit_button("Apply Model")
119
 
120
  if apply_model:
121
- if 'Test' in st.session_state.embeddings_df:
122
  test_videos = True
123
  else:
124
- print(f'shape of df: {st.session_state.embeddings_df.shape[0]}')
125
- test_videos_array = [True for i in range(st.session_state.embeddings_df.shape[0])]
126
- st.session_state.embeddings_df['Test'] = test_videos_array
127
  test_videos = True
128
 
129
- kwargs = {'embeddings_df' : st.session_state.embeddings_df,
130
  'specified_classes' : specified_classes,
131
  'classes_to_remove' : None,
132
  'max_class_size' : None,
 
49
  predictions = multiclass_merge_and_filter_bouts(test_pred, bout_threshold, proximity_threshold)
50
  return predictions
51
 
52
+ if "embeddings_df_apply" not in st.session_state:
53
+ st.session_state.embeddings_df_apply = None
54
 
55
  if "smoothed_predictions" not in st.session_state:
56
  st.session_state.smoothed_predictions = None
 
84
  views=None,
85
  conditions=None,
86
  downsample_rate=downsample_rate)
87
+ st.session_state.embeddings_df_apply = embeddings_df
88
 
89
  elif embeddings_csv is not None:
90
  embeddings_df = pd.read_csv(embeddings_csv)
91
+ st.session_state.embeddings_df_apply = embeddings_df
92
  else:
93
  st.text('Please upload file(s).')
94
 
 
105
  svm_clf = None
106
 
107
  st.divider()
108
+ if st.session_state.embeddings_df_apply is not None and svm_clf is not None:
109
  st.subheader("specify dataset labels")
110
+ label_list = st.session_state.embeddings_df_apply['Label'].to_list()
111
  unique_label_list = get_unique_labels(label_list)
112
 
113
  with st.form('apply_model_settings'):
 
118
  apply_model = st.form_submit_button("Apply Model")
119
 
120
  if apply_model:
121
+ if 'Test' in st.session_state.embeddings_df_apply:
122
  test_videos = True
123
  else:
124
+ print(f'shape of df: {st.session_state.embeddings_df_apply.shape[0]}')
125
+ test_videos_array = [True for i in range(st.session_state.embeddings_df_apply.shape[0])]
126
+ st.session_state.embeddings_df_apply['Test'] = test_videos_array
127
  test_videos = True
128
 
129
+ kwargs = {'embeddings_df' : st.session_state.embeddings_df_apply,
130
  'specified_classes' : specified_classes,
131
  'classes_to_remove' : None,
132
  'max_class_size' : None,