detection-demo / app.py
kevinconka's picture
Add decode_blob_data function to handle Gradio image blob conversion in app.py
b3f650d
raw
history blame
8.06 kB
"""
Main application for RGB detection demo.
Any new model should implement the following functions:
- load_model(model_path, img_size=640)
- inference(model, image)
"""
import os
import glob
import hashlib
import struct
# import spaces
import gradio as gr
from huggingface_hub import get_token
from utils import (
check_image,
load_image_from_url,
load_badges,
FlaggedCounter,
)
from flagging import HuggingFaceDatasetSaver
import install_private_repos # noqa: F401
from seavision import load_model
# HTML heading rendered at the top of the demo page.
TITLE = """
<h1> 🌊 SEA.AI's Vision Demo ✨ </h1>
<p align="center">
Ahoy! Explore our object detection technology!
Upload a maritime scene image and click <code>Submit</code>
to see the results.
</p>
"""
# Label for the flag button; also interpolated into NOTICE below.
FLAG_TXT = "Report Mis-detection"
# Markdown shown next to the flag button once a result is displayed.
NOTICE = f"""
🚩 See something off? Your feedback makes a difference! Let us know by
flagging any outcomes that don't seem right. Click the `{FLAG_TXT}` button
to submit the image for review.
"""
# Page-level CSS: center the <h1> produced by TITLE.
css = """
h1 {
text-align: center;
display: block;
}
"""
# Load the ONNX detection model once at import time (shared by all requests).
model = load_model("experimental/ahoy6-MIX-1280-b1.onnx")
# Low confidence thresholds for object detections and horizon estimation —
# presumably tuned for demo recall over precision; confirm with model owners.
model.det_conf_thresh = 0.1
model.hor_conf_thresh = 0.1
# @spaces.GPU
def inference(image):
    """Detect objects in ``image`` and return the annotated image."""
    detections = model(image)
    annotated = detections.draw(image)
    return annotated
def decode_blob_data(image_data):
    """
    Decode blob data from Gradio image component.

    Gradio sometimes hands flagged images over as anonymous ``blob`` temp
    files with no size, original name, or MIME type. This helper sniffs the
    file's magic bytes, copies the blob to a properly named sibling file, and
    returns an updated Gradio file-data dict. Anything that is not blob data
    is returned unchanged.

    Args:
        image_data: Gradio file-data dict (or any other value, which is
            passed through untouched).

    Returns:
        The original ``image_data`` when no conversion is needed, otherwise a
        new dict with ``path``, ``url``, ``size``, ``orig_name`` and
        ``mime_type`` filled in.
    """
    if not isinstance(image_data, dict):
        return image_data

    print(f"DEBUG: Original input - image: {image_data}")

    path = image_data.get('path')
    # Blob uploads have a temp path containing "blob" and no file metadata.
    # isinstance guard: 'path' may be present but None, which would make
    # the `in` test raise TypeError.
    is_blob = (
        isinstance(path, str) and
        'blob' in path and
        image_data.get('size') is None and
        image_data.get('orig_name') is None and
        image_data.get('mime_type') is None
    )
    if not is_blob:
        print("DEBUG: Not a blob, skipping conversion")
        print(f"DEBUG: Converted image: {image_data}")
        return image_data

    print(f"DEBUG: Converting blob data: {image_data}")
    with open(path, 'rb') as f:
        blob_content = f.read()
    file_size = len(blob_content)

    extension, mime_type = _sniff_image_format(blob_content)
    print(f"DEBUG: Detected extension: {extension}, MIME type: {mime_type}")

    # Content-addressed name (md5 used for naming only, not security) so
    # repeated flags of the same image map to the same file.
    content_hash = hashlib.md5(blob_content).hexdigest()[:8]
    new_filename = f"flagged_image_{content_hash}{extension}"
    new_path = os.path.join(os.path.dirname(path), new_filename)
    with open(new_path, 'wb') as f:
        f.write(blob_content)
    print(f"DEBUG: Successfully renamed blob to: {new_path}")

    # 'url' may be missing or None for blob entries; only rewrite it when set.
    url = image_data.get('url')
    converted_data = {
        'path': new_path,
        'url': url.replace('blob', new_filename) if url else url,
        'size': file_size,
        'orig_name': new_filename,
        'mime_type': mime_type,
        'is_stream': False,
        'meta': image_data.get('meta', {}),
    }
    print(f"DEBUG: Converted data: {converted_data}")
    return converted_data


def _sniff_image_format(blob_content):
    """Return ``(extension, mime_type)`` inferred from magic bytes.

    Recognizes JPEG, GIF and PNG; falls back to PNG when the header is
    unrecognized or the content is too short to identify.
    """
    header = blob_content[:8].hex()
    if header.startswith('ffd8ff'):        # JPEG: FF D8 FF
        return '.jpg', 'image/jpeg'
    if header.startswith('47494638'):      # GIF: "GIF8"
        return '.gif', 'image/gif'
    # PNG header 89 50 4E 47 0D 0A 1A 0A — also the fallback default.
    return '.png', 'image/png'
def flag_img_input(
    image: gr.Image, flag_option: str = "misdetection", username: str = "anonymous"
):
    """Persist a flagged image (with flag metadata) to the HF dataset."""
    print(f"{image=}, {flag_option=}, {username=}")
    # Normalize Gradio blob uploads into a proper file entry before saving.
    payload = decode_blob_data(image)
    hf_writer.flag([payload], flag_option=flag_option, username=username)
# Flagging
# Flagged images are appended to this public HF dataset; the same dataset
# backs the flagged-image counter shown in the page badges.
dataset_name = "SEA-AI/crowdsourced-sea-images"
hf_writer = HuggingFaceDatasetSaver(get_token(), dataset_name)  # authenticated with the Space's HF token
# presumably counts entries already flagged to the dataset — see utils.FlaggedCounter
flagged_counter = FlaggedCounter(dataset_name)
theme = gr.themes.Default(primary_hue=gr.themes.colors.indigo)
# Declarative UI layout + event wiring. Statement order matters here:
# components must exist before event listeners reference them, and
# hf_writer.setup() must run before the first flag.click() fires.
with gr.Blocks(theme=theme, css=css, title="SEA.AI Vision Demo") as demo:
    badges = gr.HTML(load_badges(flagged_counter.count()))
    title = gr.HTML(TITLE)
    with gr.Row():
        with gr.Column():
            # Left column: image input (upload/clipboard or URL) + buttons.
            img_input = gr.Image(
                label="input",
                interactive=True,
                sources=["upload", "clipboard"],
            )
            img_url = gr.Textbox(
                lines=1,
                placeholder="or enter URL to image here",
                label="input_url",
                show_label=False,
            )
            with gr.Row():
                clear = gr.ClearButton()
                submit = gr.Button("Submit", variant="primary")
        with gr.Column():
            # Right column: annotated result; flag button/notice start hidden
            # and are toggled by _show_hide_flagging below.
            img_output = gr.Image(label="output", interactive=False)
            flag = gr.Button(FLAG_TXT, visible=False)
            notice = gr.Markdown(value=NOTICE, visible=False)
    examples = gr.Examples(
        examples=glob.glob("examples/*.jpg"),
        inputs=img_input,
        outputs=img_output,
        fn=inference,
        cache_examples=True,
    )
    # add components to clear when clear button is clicked
    clear.add([img_input, img_url, img_output])
    # event listeners
    # Pasting a URL loads the image into the input component.
    img_url.change(load_image_from_url, [img_url], img_input)
    # Validate the input first; run inference only if the check succeeds.
    submit.click(check_image, [img_input], None, show_api=False).success(
        inference,
        [img_input],
        img_output,
        api_name="inference",
    )
    # event listeners with decorators
    @img_output.change(
        inputs=[img_input, img_output],
        outputs=[flag, notice],
        show_api=False,
        preprocess=False,
        show_progress="hidden",
    )
    def _show_hide_flagging(_img_input, _img_output):
        """Show flag button + notice only for non-example outputs."""
        # preprocess=False → _img_input is a raw Gradio file-data dict.
        # NOTE(review): assumes _img_input is a dict with "orig_name" — this
        # may raise if the input was cleared while an output exists; confirm.
        visible = _img_output and _img_input["orig_name"] not in os.listdir("examples")
        return {
            flag: gr.Button(FLAG_TXT, interactive=True, visible=visible),
            notice: gr.Markdown(value=NOTICE, visible=visible),
        }
    # This needs to be called prior to the first call to callback.flag()
    hf_writer.setup([img_input], "flagged")
    # Sequential logic when flag button is clicked
    # thank-you toast → disable button → persist image → refresh badges
    flag.click(lambda: gr.Info("Thank you for contributing!"), show_api=False).then(
        lambda: {flag: gr.Button(FLAG_TXT, interactive=False)},
        [],
        [flag],
        show_api=False,
    ).then(
        flag_img_input,
        [img_input],
        [],
        preprocess=False,
        show_api=True,
        api_name="flag_misdetection",
    ).then(
        lambda: load_badges(flagged_counter.count()),
        [],
        badges,
        show_api=False,
    )
# called during initial load in browser
demo.load(lambda: load_badges(flagged_counter.count()), [], badges, show_api=False)
if __name__ == "__main__":
    # queue() enables Gradio's request queue before launching the server.
    demo.queue().launch()