Spaces:
Sleeping
Sleeping
import streamlit as st | |
import logging | |
import os | |
# get a global var for logger accessor in this module | |
LOG_LEVEL = logging.DEBUG | |
g_logger = logging.getLogger(__name__) | |
g_logger.setLevel(LOG_LEVEL) | |
from grid_maker import gridder | |
import hf_push_observations as sw_push_obs | |
import utils.metadata_handler as meta_handler | |
import whale_viewer as sw_wv | |
def cetacean_classify(cetacean_classifier, tab_inference): | |
files = st.session_state.files | |
images = st.session_state.images | |
observations = st.session_state.observations | |
batch_size, row_size, page = gridder(files) | |
grid = st.columns(row_size) | |
col = 0 | |
for file in files: | |
image = images[file.name] | |
with grid[col]: | |
st.image(image, use_column_width=True) | |
observation = observations[file.name].to_dict() | |
# run classifier model on `image`, and persistently store the output | |
out = cetacean_classifier(image) # get top 3 matches | |
st.session_state.whale_prediction1 = out['predictions'][0] | |
st.session_state.classify_whale_done = True | |
msg = f"[D]2 classify_whale_done: {st.session_state.classify_whale_done}, whale_prediction1: {st.session_state.whale_prediction1}" | |
g_logger.info(msg) | |
# dropdown for selecting/overriding the species prediction | |
if not st.session_state.classify_whale_done: | |
selected_class = st.sidebar.selectbox("Species", sw_wv.WHALE_CLASSES, | |
index=None, placeholder="Species not yet identified...", | |
disabled=True) | |
else: | |
pred1 = st.session_state.whale_prediction1 | |
# get index of pred1 from WHALE_CLASSES, none if not present | |
print(f"[D] pred1: {pred1}") | |
ix = sw_wv.WHALE_CLASSES.index(pred1) if pred1 in sw_wv.WHALE_CLASSES else None | |
selected_class = tab_inference.selectbox("Species", sw_wv.WHALE_CLASSES, index=ix) | |
observation['predicted_class'] = selected_class | |
if selected_class != st.session_state.whale_prediction1: | |
observation['class_overriden'] = selected_class | |
st.session_state.public_observation = observation | |
st.button(f"Upload observation for {file.name} to THE INTERNET!", on_click=sw_push_obs.push_observations) | |
# TODO: the metadata only fills properly if `validate` was clicked. | |
st.markdown(meta_handler.metadata2md()) | |
msg = f"[D] full observation after inference: {observation}" | |
g_logger.debug(msg) | |
print(msg) | |
# TODO: add a link to more info on the model, next to the button. | |
whale_classes = out['predictions'][:] | |
# render images for the top 3 (that is what the model api returns) | |
#with tab_inference: | |
st.title(f"Species detected for {file.name}") | |
for i in range(len(whale_classes)): | |
sw_wv.display_whale(whale_classes, i) | |
col = (col + 1) % row_size |