"""
Main application for RGB detection demo.
Any new model should implement the following functions:
- load_model(model_path, img_size=640)
- inference(model, image)
"""
import os
import glob
import hashlib
import struct
# import spaces
import gradio as gr
from huggingface_hub import get_token
from utils import (
check_image,
load_image_from_url,
load_badges,
FlaggedCounter,
)
from flagging import HuggingFaceDatasetSaver
import install_private_repos # noqa: F401
from seavision import load_model
# Intro markup rendered at the top of the demo page.
TITLE = """
🌊 SEA.AI's Vision Demo ✨
Ahoy! Explore our object detection technology!
Upload a maritime scene image and click Submit
to see the results.
"""
# Label of the flagging button; interpolated into NOTICE below so the
# two always stay in sync.
FLAG_TXT = "Report Mis-detection"
NOTICE = f"""
🚩 See something off? Your feedback makes a difference! Let us know by
flagging any outcomes that don't seem right. Click the `{FLAG_TXT}` button
to submit the image for review.
"""
# Center the page title (h1) in the Gradio layout.
css = """
h1 {
text-align: center;
display: block;
}
"""
# Load the ONNX model once at import time so every request shares it.
model = load_model("experimental/ahoy6-MIX-1280-b1.onnx")
# Presumably detection / horizon confidence thresholds, lowered to favour
# recall in the demo — TODO confirm semantics in the seavision package.
model.det_conf_thresh = 0.1
model.hor_conf_thresh = 0.1
# @spaces.GPU
def inference(image):
    """Detect objects in ``image`` and return the annotated image."""
    detections = model(image)
    annotated = detections.draw(image)
    return annotated
def decode_blob_data(image_data):
    """
    Decode blob data from Gradio image component.

    Gradio occasionally hands over an anonymous temporary "blob" file with
    no size/name/MIME metadata.  This sniffs the file's magic bytes, copies
    the blob to a properly named file, and returns an updated
    FileData-style dict.  Anything that is not a blob dict is returned
    unchanged.
    """
    if not isinstance(image_data, dict):
        return image_data
    print(f"DEBUG: Original input - image: {image_data}")
    # Guard against a None path: 'blob' in None would raise TypeError.
    path = image_data.get('path')
    # A blob upload is recognised by a 'blob' path plus missing metadata.
    is_blob = (
        path is not None and
        'blob' in path and
        image_data.get('size') is None and
        image_data.get('orig_name') is None and
        image_data.get('mime_type') is None
    )
    if not is_blob:
        print("DEBUG: Not a blob, skipping conversion")
        print(f"DEBUG: Converted image: {image_data}")
        return image_data
    print(f"DEBUG: Converting blob data: {image_data}")
    print("DEBUG: Detected blob format, converting...")
    blob_path = path
    print(f"DEBUG: Blob path: {blob_path}")
    # Read the blob file
    with open(blob_path, 'rb') as f:
        blob_content = f.read()
    file_size = len(blob_content)
    print(f"DEBUG: File size: {file_size}")
    extension, mime_type = _sniff_image_format(blob_content)
    print(f"DEBUG: Detected extension: {extension}, MIME type: {mime_type}")
    # Name the copy after the content hash so re-flagging the same image
    # maps to the same file.
    content_hash = hashlib.md5(blob_content).hexdigest()[:8]
    new_filename = f"flagged_image_{content_hash}{extension}"
    print(f"DEBUG: Generated filename: {new_filename}")
    # Create the new file next to the original blob.
    new_path = os.path.join(os.path.dirname(blob_path), new_filename)
    print(f"DEBUG: New path: {new_path}")
    with open(new_path, 'wb') as f:
        f.write(blob_content)
    print(f"DEBUG: Successfully renamed blob to: {new_path}")
    # Guard against a missing URL instead of crashing on None.replace.
    url = image_data.get('url')
    converted_data = {
        'path': new_path,
        'url': url.replace('blob', new_filename) if url else None,
        'size': file_size,
        'orig_name': new_filename,
        'mime_type': mime_type,
        'is_stream': False,
        'meta': image_data.get('meta', {})
    }
    print(f"DEBUG: Converted data: {converted_data}")
    return converted_data


def _sniff_image_format(blob_content):
    """Return ``(extension, mime_type)`` guessed from the file's magic bytes."""
    header = blob_content[:8].hex()
    print(f"DEBUG: File header: {header}")
    # PNG header: 89 50 4E 47 0D 0A 1A 0A
    if header.startswith('89504e470d0a1a0a'):
        return '.png', 'image/png'
    # JPEG header: FF D8 FF
    if header.startswith('ffd8ff'):
        return '.jpg', 'image/jpeg'
    # GIF header: 47 49 46 38
    if header.startswith('47494638'):
        return '.gif', 'image/gif'
    # Default to PNG when the header is short or unrecognised.
    return '.png', 'image/png'
def flag_img_input(
    image: gr.Image, flag_option: str = "misdetection", username: str = "anonymous"
):
    """Forward a flagged image to the HuggingFace dataset writer."""
    print(f"{image=}, {flag_option=}, {username=}")
    # Normalise Gradio blob uploads before handing off to the writer.
    payload = decode_blob_data(image)
    hf_writer.flag([payload], flag_option=flag_option, username=username)
# Flagging
# Flagged images are pushed to this HF dataset for later review.
dataset_name = "SEA-AI/crowdsourced-sea-images"
hf_writer = HuggingFaceDatasetSaver(get_token(), dataset_name)
# Counts previously flagged images to render the badge in the header.
flagged_counter = FlaggedCounter(dataset_name)
theme = gr.themes.Default(primary_hue=gr.themes.colors.indigo)
# UI layout: input image / URL on the left, annotated output plus the
# (initially hidden) flag button and notice on the right.
with gr.Blocks(theme=theme, css=css, title="SEA.AI Vision Demo") as demo:
    badges = gr.HTML(load_badges(flagged_counter.count()))
    title = gr.HTML(TITLE)
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(
                label="input",
                interactive=True,
                sources=["upload", "clipboard"],
            )
            img_url = gr.Textbox(
                lines=1,
                placeholder="or enter URL to image here",
                label="input_url",
                show_label=False,
            )
            with gr.Row():
                clear = gr.ClearButton()
                submit = gr.Button("Submit", variant="primary")
        with gr.Column():
            img_output = gr.Image(label="output", interactive=False)
            # Hidden until an output exists; toggled by _show_hide_flagging.
            flag = gr.Button(FLAG_TXT, visible=False)
            notice = gr.Markdown(value=NOTICE, visible=False)
    examples = gr.Examples(
        examples=glob.glob("examples/*.jpg"),
        inputs=img_input,
        outputs=img_output,
        fn=inference,
        cache_examples=True,
    )
    # add components to clear when clear button is clicked
    clear.add([img_input, img_url, img_output])
    # event listeners
    # Typing a URL loads the image into the input component.
    img_url.change(load_image_from_url, [img_url], img_input)
    # Validate the input first; only run inference when the check succeeds.
    submit.click(check_image, [img_input], None, show_api=False).success(
        inference,
        [img_input],
        img_output,
        api_name="inference",
    )
    # event listeners with decorators
    @img_output.change(
        inputs=[img_input, img_output],
        outputs=[flag, notice],
        show_api=False,
        preprocess=False,
        show_progress="hidden",
    )
    def _show_hide_flagging(_img_input, _img_output):
        """Show the flag button/notice only for non-example outputs."""
        # NOTE(review): with preprocess=False, _img_input is assumed to be a
        # FileData dict carrying "orig_name"; if the output is set while the
        # input is None this raises — confirm Gradio guarantees otherwise.
        visible = _img_output and _img_input["orig_name"] not in os.listdir("examples")
        return {
            flag: gr.Button(FLAG_TXT, interactive=True, visible=visible),
            notice: gr.Markdown(value=NOTICE, visible=visible),
        }
    # This needs to be called prior to the first call to callback.flag()
    hf_writer.setup([img_input], "flagged")
    # Sequential logic when flag button is clicked
    # 1) thank the user, 2) disable the button to prevent double submits,
    # 3) upload the image, 4) refresh the flagged-count badge.
    flag.click(lambda: gr.Info("Thank you for contributing!"), show_api=False).then(
        lambda: {flag: gr.Button(FLAG_TXT, interactive=False)},
        [],
        [flag],
        show_api=False,
    ).then(
        flag_img_input,
        [img_input],
        [],
        preprocess=False,
        show_api=True,
        api_name="flag_misdetection",
    ).then(
        lambda: load_badges(flagged_counter.count()),
        [],
        badges,
        show_api=False,
    )
# called during initial load in browser
demo.load(lambda: load_badges(flagged_counter.count()), [], badges, show_api=False)
if __name__ == "__main__":
    demo.queue().launch()