rmm commited on
Commit
4854d2c
·
1 Parent(s): 7a5f0ca

feat: using FSM for full workflow, with some steps mocked

Browse files

- dropped the "ML running" phase for now as we don't do it async

src/classifier/classifier_image.py CHANGED
@@ -11,6 +11,15 @@ from hf_push_observations import push_observations
11
  from utils.grid_maker import gridder
12
  from utils.metadata_handler import metadata2md
13
 
 
 
 
 
 
 
 
 
 
14
  def cetacean_classify(cetacean_classifier):
15
  """Cetacean classifier using the saving-willy model from Saving Willy Hugging Face space.
16
  For each image in the session state, classify the image and display the top 3 predictions.
 
11
  from utils.grid_maker import gridder
12
  from utils.metadata_handler import metadata2md
13
 
14
+ def add_header_text() -> None:
15
+ """
16
+ Add brief explainer text about cetacean classification to the tab
17
+ """
18
+ st.markdown("""
19
+ *Run classifer to identify the species of cetean on the uploaded image.
20
+ Once inference is complete, the top three predictions are shown.
21
+ You can override the prediction by selecting a species from the dropdown.*""")
22
+
23
  def cetacean_classify(cetacean_classifier):
24
  """Cetacean classifier using the saving-willy model from Saving Willy Hugging Face space.
25
  For each image in the session state, classify the image and display the top 3 predictions.
src/main.py CHANGED
@@ -9,7 +9,8 @@ from streamlit_folium import st_folium
9
  from transformers import pipeline
10
  from transformers import AutoModelForImageClassification
11
 
12
- from maps.obs_map import add_header_text
 
13
  from datasets import disable_caching
14
  disable_caching()
15
 
@@ -79,18 +80,20 @@ if "workflow_fsm" not in st.session_state:
79
  # create and init the state machine
80
  st.session_state.workflow_fsm = WorkflowFSM(FSM_STATES)
81
 
82
- # add progress indicator to session_state
83
- if "progress" not in st.session_state:
84
- with st.sidebar:
85
- st.session_state.disp_progress = [st.empty(), st.empty()]
86
-
87
  def refresh_progress():
88
  with st.sidebar:
89
- tot = st.session_state.workflow_fsm.num_states
90
  cur_i = st.session_state.workflow_fsm.current_state_index
91
  cur_t = st.session_state.workflow_fsm.current_state
92
  st.session_state.disp_progress[0].markdown(f"*Progress: {cur_i}/{tot}. Current: {cur_t}.*")
93
  st.session_state.disp_progress[1].progress(cur_i/tot)
 
 
 
 
 
 
 
94
 
95
 
96
  def main() -> None:
@@ -125,10 +128,8 @@ def main() -> None:
125
  st.tabs(["Cetecean classifier", "Hotdog classifier", "Map", "*:gray[Dev:coordinates]*", "Log", "Beautiful cetaceans"])
126
  st.session_state.tab_log = tab_log
127
 
 
128
  refresh_progress()
129
- # add button to sidebar, with the callback to refesh_progress
130
- st.sidebar.button("Refresh Progress", on_click=refresh_progress)
131
-
132
 
133
  # create a sidebar, and parse all the input (returned as `observations` object)
134
  setup_input(viewcontainer=st.sidebar)
@@ -149,7 +150,7 @@ def main() -> None:
149
  with tab_map:
150
  # visual structure: a couple of toggles at the top, then the map inlcuding a
151
  # dropdown for tileset selection.
152
- add_header_text()
153
  tab_map_ui_cols = st.columns(2)
154
  with tab_map_ui_cols[0]:
155
  show_db_points = st.toggle("Show Points from DB", True)
@@ -207,24 +208,108 @@ def main() -> None:
207
  gallery.render_whale_gallery(n_cols=4)
208
 
209
 
210
- # Display submitted observation
211
- all_inputs_set = check_inputs_are_set(debug=True)
212
- if not all_inputs_set:
213
- st.sidebar.button(":gray[*Validate*]", disabled=True, help="Please fill in all fields.")
214
-
215
- else:
216
- if st.session_state.workflow_fsm.is_in_state('init'):
 
 
 
 
 
217
  st.session_state.workflow_fsm.complete_current_state()
218
-
219
- if st.sidebar.button("**Validate**"):
220
- if st.session_state.workflow_fsm.is_in_state('data_entry_complete'):
221
- st.session_state.workflow_fsm.complete_current_state()
222
-
 
 
 
 
223
  # create a dictionary with the submitted observation
224
  tab_log.info(f"{st.session_state.observations}")
225
  df = pd.DataFrame(st.session_state.observations, index=[0])
226
  with tab_coords:
227
  st.table(df)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
 
229
 
230
 
@@ -235,23 +320,24 @@ def main() -> None:
235
  # - these species are shown
236
  # - the user can override the species prediction using the dropdown
237
  # - an observation is uploaded if the user chooses.
238
- tab_inference.markdown("""
239
- *Run classifer to identify the species of cetean on the uploaded image.
240
- Once inference is complete, the top three predictions are shown.
241
- You can override the prediction by selecting a species from the dropdown.*""")
242
 
243
- if tab_inference.button("Identify with cetacean classifier"):
244
- #pipe = pipeline("image-classification", model="Saving-Willy/cetacean-classifier", trust_remote_code=True)
245
- cetacean_classifier = AutoModelForImageClassification.from_pretrained("Saving-Willy/cetacean-classifier",
246
- revision=classifier_revision,
247
- trust_remote_code=True)
248
 
249
 
250
- if st.session_state.images is None:
251
- # TODO: cleaner design to disable the button until data input done?
252
- st.info("Please upload an image first.")
253
- else:
254
- cetacean_classify(cetacean_classifier)
 
 
 
 
 
 
 
255
 
256
 
257
 
 
9
  from transformers import pipeline
10
  from transformers import AutoModelForImageClassification
11
 
12
+ from maps.obs_map import add_header_text as add_obs_map_header
13
+ from classifier.classifier_image import add_header_text as add_classifier_header
14
  from datasets import disable_caching
15
  disable_caching()
16
 
 
80
  # create and init the state machine
81
  st.session_state.workflow_fsm = WorkflowFSM(FSM_STATES)
82
 
 
 
 
 
 
83
  def refresh_progress():
84
  with st.sidebar:
85
+ tot = st.session_state.workflow_fsm.num_states - 1
86
  cur_i = st.session_state.workflow_fsm.current_state_index
87
  cur_t = st.session_state.workflow_fsm.current_state
88
  st.session_state.disp_progress[0].markdown(f"*Progress: {cur_i}/{tot}. Current: {cur_t}.*")
89
  st.session_state.disp_progress[1].progress(cur_i/tot)
90
+ # add progress indicator to session_state
91
+ if "progress" not in st.session_state:
92
+ with st.sidebar:
93
+ st.session_state.disp_progress = [st.empty(), st.empty()]
94
+ # add button to sidebar, with the callback to refesh_progress
95
+ st.sidebar.button("Refresh Progress", on_click=refresh_progress)
96
+
97
 
98
 
99
  def main() -> None:
 
128
  st.tabs(["Cetecean classifier", "Hotdog classifier", "Map", "*:gray[Dev:coordinates]*", "Log", "Beautiful cetaceans"])
129
  st.session_state.tab_log = tab_log
130
 
131
+ # put this early so the progress indicator is at the top (also refreshed at end)
132
  refresh_progress()
 
 
 
133
 
134
  # create a sidebar, and parse all the input (returned as `observations` object)
135
  setup_input(viewcontainer=st.sidebar)
 
150
  with tab_map:
151
  # visual structure: a couple of toggles at the top, then the map inlcuding a
152
  # dropdown for tileset selection.
153
+ add_obs_map_header()
154
  tab_map_ui_cols = st.columns(2)
155
  with tab_map_ui_cols[0]:
156
  show_db_points = st.toggle("Show Points from DB", True)
 
208
  gallery.render_whale_gallery(n_cols=4)
209
 
210
 
211
+ # state handling re data_entry phases
212
+ # 0. no data entered yet -> display the file uploader thing
213
+ # 1. we have some images, but not all the metadata fields are done -> validate button shown, disabled
214
+ # 2. all data entered -> validate button enabled
215
+ # 3. validation button pressed, validation done -> enable the inference button.
216
+ # - at this point do we also want to disable changes to the metadata selectors?
217
+ # anyway, simple first.
218
+
219
+ if st.session_state.workflow_fsm.is_in_state('doing_data_entry'):
220
+ # can we advance state? - only when all inputs are set for all uploaded files
221
+ all_inputs_set = check_inputs_are_set(debug=True)
222
+ if all_inputs_set:
223
  st.session_state.workflow_fsm.complete_current_state()
224
+ # -> data_entry_complete
225
+ else:
226
+ # button, disabled; no state change yet.
227
+ st.sidebar.button(":gray[*Validate*]", disabled=True, help="Please fill in all fields.")
228
+
229
+
230
+ if st.session_state.workflow_fsm.is_in_state('data_entry_complete'):
231
+ # can we advance state? - only when the validate button is pressed
232
+ if st.sidebar.button(":white_check_mark:[*Validate*]"):
233
  # create a dictionary with the submitted observation
234
  tab_log.info(f"{st.session_state.observations}")
235
  df = pd.DataFrame(st.session_state.observations, index=[0])
236
  with tab_coords:
237
  st.table(df)
238
+ # there doesn't seem to be any actual validation here?? TODO: find validator function (each element is validated by the input box, but is there something at the whole image level?)
239
+ # hmm, maybe it should actually just be "I'm done with data entry"
240
+ st.session_state.workflow_fsm.complete_current_state()
241
+ # -> data_entry_validated
242
+
243
+ # state handling re inference phases (tab_inference)
244
+ # 3. validation button pressed, validation done -> enable the inference button.
245
+ # 4. inference button pressed -> ML started. | let's cut this one out, since it would only
246
+ # make sense if we did it as an async action
247
+ # 5. ML done -> show results, and manual validation options
248
+ # 6. manual validation done -> enable the upload buttons
249
+ #
250
+ with tab_inference:
251
+ add_classifier_header()
252
+ # if we are before data_entry_validated, show the button, disabled.
253
+ if not st.session_state.workflow_fsm.is_in_state_or_beyond('data_entry_validated'):
254
+ tab_inference.button(":gray[*Identify with cetacean classifier*]", disabled=True,
255
+ help="Please validate inputs before proceeding",
256
+ key="button_infer_ceteans")
257
+
258
+ if st.session_state.workflow_fsm.is_in_state('data_entry_validated'):
259
+ # show the button, enabled. If pressed, we start the ML model (And advance state)
260
+ if tab_inference.button("Identify with cetacean classifier"):
261
+ cetacean_classifier = AutoModelForImageClassification.from_pretrained(
262
+ "Saving-Willy/cetacean-classifier",
263
+ revision=classifier_revision,
264
+ trust_remote_code=True)
265
+
266
+ cetacean_classify(cetacean_classifier)
267
+ st.session_state.workflow_fsm.complete_current_state()
268
+
269
+ if st.session_state.workflow_fsm.is_in_state('ml_classification_completed'):
270
+ # show the results, and allow manual validation
271
+ s = ""
272
+ for k, v in st.session_state.whale_prediction1.items():
273
+ s += f"* Image {k}: {v}\n"
274
+
275
+ st.markdown("""
276
+ ### Inference Results and manual validation/adjustment
277
+ :construction: for now we just show the num images processed.
278
+ """)
279
+ st.markdown(s)
280
+ # add a button to advance the state
281
+ if st.button("mock: manual validation done."):
282
+ st.session_state.workflow_fsm.complete_current_state()
283
+ # -> manual_inspection_completed
284
+
285
+ if st.session_state.workflow_fsm.is_in_state('manual_inspection_completed'):
286
+ # show the ML results, and allow the user to upload the observation
287
+ st.markdown("""
288
+ ### Inference Results (after manual validation)
289
+ :construction: for now we just show the button.
290
+ """)
291
+
292
+
293
+ if st.button("(nooop) Upload observation to THE INTERNET!"):
294
+ st.session_state.workflow_fsm.complete_current_state()
295
+ # -> data_uploaded
296
+
297
+ if st.session_state.workflow_fsm.is_in_state('data_uploaded'):
298
+ # the data has been sent. Lets show the observations again
299
+ # but no buttons to upload (or greyed out ok)
300
+ st.markdown("""
301
+ ### Observation(s) uploaded
302
+ :construction: for now we just show the observations.
303
+ """)
304
+ df = pd.DataFrame(st.session_state.observations, index=[0])
305
+ st.table(df)
306
+
307
+ # didn't decide what the next state is here - I think we are in the terminal state.
308
+ #st.session_state.workflow_fsm.complete_current_state()
309
+
310
+
311
+
312
+
313
 
314
 
315
 
 
320
  # - these species are shown
321
  # - the user can override the species prediction using the dropdown
322
  # - an observation is uploaded if the user chooses.
323
+
324
+ # with tab_inference:
325
+ # add_classifier_header()
 
326
 
 
 
 
 
 
327
 
328
 
329
+ # if tab_inference.button("Identify with cetacean classifier"):
330
+ # #pipe = pipeline("image-classification", model="Saving-Willy/cetacean-classifier", trust_remote_code=True)
331
+ # cetacean_classifier = AutoModelForImageClassification.from_pretrained("Saving-Willy/cetacean-classifier",
332
+ # revision=classifier_revision,
333
+ # trust_remote_code=True)
334
+
335
+
336
+ # if st.session_state.images is None:
337
+ # # TODO: cleaner design to disable the button until data input done?
338
+ # st.info("Please upload an image first.")
339
+ # else:
340
+ # cetacean_classify(cetacean_classifier)
341
 
342
 
343
 
src/utils/workflow_state.py CHANGED
@@ -8,7 +8,10 @@ FAIL = '\033[91m'
8
  ENDC = '\033[0m'
9
 
10
 
11
- FSM_STATES = ['init', 'data_entry_complete', 'data_entry_validated', 'ml_classification_started', 'ml_classification_completed', 'manual_inspection_completed', 'data_uploaded']
 
 
 
12
 
13
 
14
  class WorkflowFSM:
@@ -64,6 +67,19 @@ class WorkflowFSM:
64
  return False
65
  return False
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  @property
68
  def current_state(self) -> str:
69
  """Get the current state name"""
 
8
  ENDC = '\033[0m'
9
 
10
 
11
+ FSM_STATES = ['doing_data_entry', 'data_entry_complete', 'data_entry_validated',
12
+ #'ml_classification_started',
13
+ 'ml_classification_completed',
14
+ 'manual_inspection_completed', 'data_uploaded']
15
 
16
 
17
  class WorkflowFSM:
 
67
  return False
68
  return False
69
 
70
+ # add a helper method, to find out if a given state has been reached/passed
71
+ # we first need to get the index of the current state
72
+ # then the index of the argument state
73
+ # compare, and return boolean
74
+
75
+ def is_in_state_or_beyond(self, state_name: str) -> bool:
76
+ """Check if we have reached or passed the specified state"""
77
+ if state_name not in self.state_dict:
78
+ raise ValueError(f"Invalid state: {state_name}")
79
+
80
+ return self.state_dict[state_name] <= self.state_dict[self.state]
81
+
82
+
83
  @property
84
  def current_state(self) -> str:
85
  """Get the current state name"""