# detection-demo / app.py
# (Hugging Face Spaces page metadata — kept as comments so the file stays valid Python)
# kevinconka's picture
# Update flagging functionality in app.py and flagging.py to use metadata for
# user information and change dataset name to versioned format
# f4ef5b4 | raw | history | blame | 5.34 kB
"""
Main application for RGB detection demo.
Any new model should implement the following functions:
- load_model(model_path, img_size=640)
- inference(model, image)
"""
import os
import glob
# import spaces
import gradio as gr
from huggingface_hub import get_token
from utils import (
check_image,
load_image_from_url,
load_badges,
FlaggedCounter,
)
from flagging import HuggingFaceDatasetSaver
from blob_utils import decode_blob_data, is_blob_data
from logging_config import get_logger
import install_private_repos # noqa: F401
from seavision import load_model
# Get loggers
logger = get_logger(__name__)

# HTML header rendered at the top of the demo page.
TITLE = """
<h1> 🌊 SEA.AI's Vision Demo ✨ </h1>
<p align="center">
Ahoy! Explore our object detection technology!
Upload a maritime scene image and click <code>Submit</code>
to see the results.
</p>
"""

# Label of the mis-detection flag button; interpolated into NOTICE below.
FLAG_TXT = "Report Mis-detection"

# Markdown notice displayed next to the flag button once a result is shown.
NOTICE = f"""
🚩 See something off? Your feedback makes a difference! Let us know by
flagging any outcomes that don't seem right. Click the `{FLAG_TXT}` button
to submit the image for review.
"""

# Center the page title (passed to gr.Blocks via css=...).
css = """
h1 {
text-align: center;
display: block;
}
"""

# Load model
logger.info("Loading detection model...")
# ONNX model loaded via the (private) seavision package.
model = load_model("experimental/ahoy6-MIX-1280-b1.onnx")
# Low confidence thresholds (0.1) for detections and "hor" — presumably
# horizon — estimation, to surface more candidates in the demo; confirm
# attribute semantics against seavision.
model.det_conf_thresh = 0.1
model.hor_conf_thresh = 0.1
# @spaces.GPU
def inference(image):
    """Run the detection model on *image* and return the annotated image."""
    logger.debug("Running inference on image")
    detections = model(image)
    logger.debug("Inference completed")
    # The results object knows how to render its own boxes onto the input.
    return detections.draw(image)
def flag_img_input(
    image: gr.Image, name: str = "misdetection", email: str = "[email protected]"
):
    """Wrapper for flagging"""
    logger.info("Flagging image - name: %s, email: %s", name, email)
    # Blob-encoded payloads must be decoded before being written out.
    if is_blob_data(image):
        image = decode_blob_data(image)
    # Reporter identity travels as metadata alongside the flagged sample.
    hf_writer.flag([image], metadata={"name": name, "email": email})
    logger.info("Image flagged successfully")
# Flagging
# Versioned HF dataset that receives crowdsourced flagged images.
dataset_name = "SEA-AI/crowdsourced-sea-images-v2"
# Saver authenticates with the Space's current HF token.
hf_writer = HuggingFaceDatasetSaver(get_token(), dataset_name)
# Counter for the badge display; presumably counts entries in the same
# dataset — verify against utils.FlaggedCounter.
flagged_counter = FlaggedCounter(dataset_name)
theme = gr.themes.Default(primary_hue=gr.themes.colors.indigo)
# UI layout and event wiring. Order matters: hf_writer.setup() must run
# before the first flag, and the flag button's .then() chain is sequential.
with gr.Blocks(theme=theme, css=css, title="SEA.AI Vision Demo") as demo:
    badges = gr.HTML(load_badges(flagged_counter.count()))
    title = gr.HTML(TITLE)
    with gr.Row():
        with gr.Column():
            # Left column: input image, URL box, clear/submit controls.
            img_input = gr.Image(
                label="input",
                interactive=True,
                sources=["upload", "clipboard"],
            )
            img_url = gr.Textbox(
                lines=1,
                placeholder="or enter URL to image here",
                label="input_url",
                show_label=False,
            )
            with gr.Row():
                clear = gr.ClearButton()
                submit = gr.Button("Submit", variant="primary")
        with gr.Column():
            # Right column: result image plus (initially hidden) flag controls.
            img_output = gr.Image(label="output", interactive=False)
            flag = gr.Button(FLAG_TXT, visible=False)
            notice = gr.Markdown(value=NOTICE, visible=False)
    examples = gr.Examples(
        examples=glob.glob("examples/*.jpg"),
        inputs=img_input,
        outputs=img_output,
        fn=inference,
        cache_examples=True,
    )
    # add components to clear when clear button is clicked
    clear.add([img_input, img_url, img_output])
    # event listeners
    img_url.change(load_image_from_url, [img_url], img_input)
    # Validate the image first; run inference only on success.
    submit.click(check_image, [img_input], None, show_api=False).success(
        inference,
        [img_input],
        img_output,
        api_name="inference",
    )
    # event listeners with decorators
    @img_output.change(
        inputs=[img_input, img_output],
        outputs=[flag, notice],
        show_api=False,
        preprocess=False,
        show_progress="hidden",
    )
    def _show_hide_flagging(_img_input, _img_output):
        # Show the flag button only for user-supplied images (not bundled
        # examples). preprocess=False means _img_input is raw file metadata
        # with an "orig_name" key.
        # NOTE(review): if _img_output is truthy while _img_input is None,
        # the subscript would raise — confirm Gradio never fires that combo.
        visible = _img_output and _img_input["orig_name"] not in os.listdir("examples")
        return {
            flag: gr.Button(FLAG_TXT, interactive=True, visible=visible),
            notice: gr.Markdown(value=NOTICE, visible=visible),
        }
    # add hidden textbox for name and email (hacky but well...)
    name = gr.Textbox(label="name", visible=False, value="anonymous")
    email = gr.Textbox(label="email", visible=False, value="[email protected]")
    # This needs to be called prior to the first call to callback.flag()
    hf_writer.setup([img_input], "flagged")
    # Sequential logic when flag button is clicked:
    # thank the user -> disable the button -> persist the flag -> refresh badge.
    flag.click(lambda: gr.Info("Thank you for contributing!"), show_api=False).then(
        lambda: {flag: gr.Button(FLAG_TXT, interactive=False)},
        [],
        [flag],
        show_api=False,
    ).then(
        flag_img_input,
        [img_input, name, email],
        [],
        preprocess=False,
        show_api=True,
        api_name="flag_misdetection",
    ).then(
        lambda: load_badges(flagged_counter.count()),
        [],
        badges,
        show_api=False,
    )
    # called during initial load in browser
    demo.load(lambda: load_badges(flagged_counter.count()), [], badges, show_api=False)
if __name__ == "__main__":
    # Enable request queuing, then start the Gradio server.
    demo.queue().launch()