ncoria commited on
Commit
3198317
·
verified ·
1 Parent(s): 4dbe38d

update explore and apply

Browse files
Files changed (2) hide show
  1. apply_model.py +7 -1
  2. explore.py +2 -2
apply_model.py CHANGED
@@ -118,6 +118,12 @@ if st.session_state.embeddings_df is not None and svm_clf is not None:
118
  apply_model = st.form_submit_button("Apply Model")
119
 
120
  if apply_model:
 
 
 
 
 
 
121
  kwargs = {'embeddings_df' : st.session_state.embeddings_df,
122
  'specified_classes' : specified_classes,
123
  'classes_to_remove' : None,
@@ -125,7 +131,7 @@ if st.session_state.embeddings_df is not None and svm_clf is not None:
125
  'animal_state' : None,
126
  'view' : None,
127
  'shuffle_data' : False,
128
- 'test_videos' : list(set(st.session_state.embeddings_df['Source'].to_list()))}
129
  train_embeds, train_labels, train_images, test_embeds, test_labels, test_images =\
130
  process_dataset_in_mem(**kwargs)
131
 
 
118
  apply_model = st.form_submit_button("Apply Model")
119
 
120
  if apply_model:
121
+ if 'Test' in st.session_state.embeddings_df.index:
122
+ test_videos = True
123
+ elif 'Images' in st.session_state.embeddings_df.index:
124
+ test_videos = True
125
+ else:
126
+ test_videos = False
127
  kwargs = {'embeddings_df' : st.session_state.embeddings_df,
128
  'specified_classes' : specified_classes,
129
  'classes_to_remove' : None,
 
131
  'animal_state' : None,
132
  'view' : None,
133
  'shuffle_data' : False,
134
+ 'test_videos' : test_videos}
135
  train_embeds, train_labels, train_images, test_embeds, test_labels, test_images =\
136
  process_dataset_in_mem(**kwargs)
137
 
explore.py CHANGED
@@ -177,7 +177,7 @@ st.subheader("generate or import embeddings")
177
 
178
  st.text("Upload files to generate embeddings.")
179
  with st.form('embedding_generation_settings'):
180
- seq_file = st.file_uploader("Choose a video file", type=['seq', 'mp4'], accept_multiple_files=True)
181
  annot_files = st.file_uploader("Choose an annotation File", type=['annot','csv'], accept_multiple_files=True)
182
  downsample_rate = st.number_input('Downsample Rate',value=4)
183
  submit_embed_settings = st.form_submit_button('Create Embeddings', type='secondary')
@@ -211,7 +211,7 @@ else:
211
  st.divider()
212
  st.subheader("provide video file if not yet already provided")
213
 
214
- uploaded_file = st.file_uploader("Choose a video file", type=['seq', 'mp4'], accept_multiple_files=True)
215
 
216
  st.divider()
217
  if st.session_state.embeddings_df is not None and (uploaded_file is not None or seq_file is not None):
 
177
 
178
  st.text("Upload files to generate embeddings.")
179
  with st.form('embedding_generation_settings'):
180
+ seq_file = st.file_uploader("Choose a video file", type=['seq', 'mp4'], accept_multiple_files=False)
181
  annot_files = st.file_uploader("Choose an annotation File", type=['annot','csv'], accept_multiple_files=True)
182
  downsample_rate = st.number_input('Downsample Rate',value=4)
183
  submit_embed_settings = st.form_submit_button('Create Embeddings', type='secondary')
 
211
  st.divider()
212
  st.subheader("provide video file if not yet already provided")
213
 
214
+ uploaded_file = st.file_uploader("Choose a video file", type=['seq', 'mp4'], accept_multiple_files=False)
215
 
216
  st.divider()
217
  if st.session_state.embeddings_df is not None and (uploaded_file is not None or seq_file is not None):