Spaces:
Running
Running
File size: 7,184 Bytes
55d18b1 71dfd99 55d18b1 19b7ec2 5a3cbc4 f1504f4 f87663a 2fd6040 f87663a 2fd6040 71dfd99 2fd6040 8c4b1f7 2fd6040 f87663a 2fd6040 71dfd99 2fd6040 8c4b1f7 2fd6040 71dfd99 2fd6040 71dfd99 5f5f7d9 71dfd99 8c4b1f7 2fd6040 84afa3b 2fd6040 8c4b1f7 2fd6040 f87663a 2fd6040 71dfd99 2fd6040 8c4b1f7 2fd6040 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
import streamlit as st
import logging
# Module-wide logger: every function below logs through `g_logger` rather
# than fetching its own. LOG_LEVEL sets its verbosity.
LOG_LEVEL: int = logging.DEBUG
g_logger: logging.Logger = logging.getLogger(__name__)
g_logger.setLevel(level=LOG_LEVEL)
import whale_viewer as viewer
from hf_push_observations import push_observations
from utils.grid_maker import gridder
from utils.metadata_handler import metadata2md
from input.input_observation import InputObservation
def init_classifier_session_states() -> None:
    """
    Make sure the classification bookkeeping exists in the Streamlit session state.

    Each key is created as an empty dict only if it is not already present, so
    results from earlier reruns are preserved:

    - ``classify_whale_done``: image hash -> True once that image was classified
    - ``whale_prediction1``: image hash -> top-1 prediction for that image
    """
    for state_key in ("classify_whale_done", "whale_prediction1"):
        if state_key in st.session_state:
            continue
        st.session_state[state_key] = {}
def add_classifier_header() -> None:
    """
    Add brief explainer text about cetacean classification to the tab.

    Renders one italicised markdown paragraph (via ``st.markdown``) telling the
    user how to run the classifier and how to override its prediction.
    """
    # Fixed user-facing typos: "classifer" -> "classifier", "cetean" -> "cetacean".
    st.markdown("""
*Run classifier to identify the species of cetacean on the uploaded image.
Once inference is complete, the top three predictions are shown.
You can override the prediction by selecting a species from the dropdown.*""")
# func to just run classification, store results.
def cetacean_just_classify(cetacean_classifier):
    """
    Classify every uploaded image and record the results in the session state.

    Nothing is displayed here except an optional developer trace. To show the
    results, use cetacean_show_results() or cetacean_show_results_and_review().

    For each hash in ``st.session_state.image_hashes``:

    - the image ``st.session_state.images[hash]`` is passed to the classifier;
    - element 0 of the returned ``'predictions'`` sequence is stored in
      ``st.session_state.whale_prediction1[hash]``;
    - ``st.session_state.classify_whale_done[hash]`` is set to True;
    - a copy of the whole predictions sequence is handed to the observation's
      ``set_top_predictions`` (the model is expected to return its top 3).

    Args:
        cetacean_classifier: callable taking one image and returning a mapping
            with a ``'predictions'`` sequence (the saving-willy model from the
            Saving Willy Hugging Face space).
    """
    state = st.session_state
    images = state.images

    for image_hash in state.image_hashes:
        # run the model on the image and keep its ranked matches
        result = cetacean_classifier(images[image_hash])
        predictions = result['predictions']

        state.whale_prediction1[image_hash] = predictions[0]
        state.classify_whale_done[image_hash] = True
        # slice copy: the observation gets its own list, detached from `result`
        state.observations[image_hash].set_top_predictions(predictions[:])

        g_logger.info(
            f"[D]2 classify_whale_done for {image_hash}: "
            f"{state.classify_whale_done[image_hash]}, "
            f"whale_prediction1: {state.whale_prediction1[image_hash]}"
        )
        if state.MODE_DEV_STATEFUL:
            st.write(f"*[D] Observation {image_hash} classified as {state.whale_prediction1[image_hash]}*")
# func to show results and allow review
def cetacean_show_results_and_review() -> None:
    """
    Present classification results and let the user review/override the prediction.

    For each image hash in ``st.session_state.image_hashes``, a grid column shows:

    - the image;
    - a species dropdown. It is disabled (no selection) when the image has not
      been classified yet. Otherwise it is pre-selected with the top-1
      prediction when that prediction is one of ``viewer.WHALE_CLASSES``;
    - summarised metadata from ``metadata2md``;
    - the top predictions stored on the observation object.

    The chosen species is written to the observation via ``set_selected_class``.
    The observation's dict form, which excludes the image, is stored in
    ``st.session_state.public_observations[hash]``. That dict is assumed to be
    initialised elsewhere; it is not created here.
    """
    images = st.session_state.images
    observations = st.session_state.observations
    hashes = st.session_state.image_hashes
    # only the row size is used here; batch size and page are ignored
    _, row_size, _ = gridder(hashes)
    grid = st.columns(row_size)

    for ix_obs, image_hash in enumerate(hashes):
        o = ix_obs + 1  # 1-based observation number shown in labels
        image = images[image_hash]
        _observation: InputObservation = observations[image_hash]
        with grid[ix_obs % row_size]:
            st.image(image, use_column_width=True)

            # Dropdown for selecting/overriding the species prediction.
            # classify_whale_done only gains an entry once an image has been
            # classified, so use .get(); plain indexing raised KeyError for
            # images that were never classified.
            if not st.session_state.classify_whale_done.get(image_hash, False):
                # Disabled placeholder, rendered in this column rather than the
                # sidebar. It uses a per-observation label because a fixed
                # label on several unclassified images would create duplicate
                # widgets.
                selected_class = st.selectbox(
                    f"Species for observation {o}", viewer.WHALE_CLASSES,
                    index=None, placeholder="Species not yet identified...",
                    disabled=True)
            else:
                pred1 = st.session_state.whale_prediction1[image_hash]
                print(f"[D] {o:3} pred1: {pred1:30} | {image_hash}")
                # pre-select pred1 if it is a known class; otherwise no default
                ix = (viewer.WHALE_CLASSES.index(pred1)
                      if pred1 in viewer.WHALE_CLASSES else None)
                selected_class = st.selectbox(
                    f"Species for observation {o}", viewer.WHALE_CLASSES,
                    index=ix)

            _observation.set_selected_class(selected_class)
            # store the elements of the observation that will be transmitted (not image)
            observation = _observation.to_dict()
            st.session_state.public_observations[image_hash] = observation

            #st.button(f"Upload observation {str(o)} to THE INTERNET!", on_click=push_observations)
            # TODO: the metadata only fills properly if `validate` was clicked.
            # TODO put condition on the debug
            st.markdown(metadata2md(image_hash, debug=False))

            msg = f"[D] full observation after inference: {observation}"
            g_logger.debug(msg)
            print(msg)

            # TODO: add a link to more info on the model, next to the button.
            whale_classes = _observation.top_predictions
            # Unclassified observations can now reach this point. Their
            # top_predictions may be unset (assumed None; confirm in
            # InputObservation), so fall back to an empty list for len().
            if whale_classes is None:
                whale_classes = []
            # render images for the top predictions (the model api returns 3)
            n = len(whale_classes)
            st.markdown(f"**Top {n} Predictions for observation {o}**")
            for i in range(n):
                viewer.display_whale(whale_classes, i)
# func to just present results
def cetacean_show_results():
    """
    Render the classification results for every observation, without review widgets.

    Observations are laid out in a grid, cycling through ``row_size`` columns
    in ``st.session_state.image_hashes`` order. Each column shows:

    - the image;
    - the metadata summary from ``metadata2md`` (called with ``debug=True``);
    - the top predictions stored on the observation object.

    Unlike cetacean_show_results_and_review(), no species dropdown is offered.
    NOTE(review): the selected species is not rendered explicitly here. It may
    appear in the metadata summary; confirm in metadata2md.
    """
    state = st.session_state
    images = state.images
    observations = state.observations
    hashes = state.image_hashes
    _, row_size, _ = gridder(hashes)  # only the row width matters here
    columns = st.columns(row_size)

    for position, image_hash in enumerate(hashes):
        picture = images[image_hash]
        record_dict = observations[image_hash].to_dict()
        obs_number = position + 1
        with columns[position % row_size]:
            st.image(picture, use_column_width=True)
            st.markdown(metadata2md(image_hash, debug=True))

            log_line = f"[D] full observation after inference: {record_dict}"
            g_logger.debug(log_line)
            print(log_line)

            # TODO: add a link to more info on the model, next to the button.
            top_matches = observations[image_hash].top_predictions
            count = len(top_matches)
            st.markdown(f"**Top {count} Predictions for observation {obs_number}**")
            for rank in range(count):
                viewer.display_whale(top_matches, rank)
|