""" | |
Main application for RGB detection demo. | |
Any new model should implement the following functions: | |
- load_model(model_path, img_size=640) | |
- inference(model, image) | |
""" | |

import os
import glob
import hashlib

# import spaces
import gradio as gr
from huggingface_hub import get_token

from utils import (
    check_image,
    load_image_from_url,
    load_badges,
    FlaggedCounter,
)
from flagging import HuggingFaceDatasetSaver

import install_private_repos  # noqa: F401
from seavision import load_model

TITLE = """
<h1> 🌊 SEA.AI's Vision Demo ✨ </h1>
<p align="center">
    Ahoy! Explore our object detection technology!
    Upload a maritime scene image and click <code>Submit</code>
    to see the results.
</p>
"""

FLAG_TXT = "Report Mis-detection"

NOTICE = f"""
🚩 See something off? Your feedback makes a difference! Let us know by
flagging any outcomes that don't seem right. Click the `{FLAG_TXT}` button
to submit the image for review.
"""

css = """
h1 {
    text-align: center;
    display: block;
}
"""

model = load_model("experimental/ahoy6-MIX-1280-b1.onnx")
model.det_conf_thresh = 0.1
model.hor_conf_thresh = 0.1
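# Low thresholds keep weak detection and horizon candidates visible in the
# demo (threshold semantics assumed from the attribute names).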


# @spaces.GPU
def inference(image):
    """Run inference on image and return annotated image."""
    results = model(image)
    return results.draw(image)


def decode_blob_data(image_data):
    """
    Decode blob data from the Gradio image component.

    Handles the blob format and converts it to a proper image file format.
    """
    if not isinstance(image_data, dict):
        return image_data

    print(f"DEBUG: Original input - image: {image_data}")

    # Check if this is blob data - more comprehensive check
    is_blob = (
        'path' in image_data and
        'blob' in image_data['path'] and
        image_data.get('size') is None and
        image_data.get('orig_name') is None and
        image_data.get('mime_type') is None
    )

    if is_blob:
        print(f"DEBUG: Converting blob data: {image_data}")
        print("DEBUG: Detected blob format, converting...")
        blob_path = image_data['path']
        print(f"DEBUG: Blob path: {blob_path}")

        # Read the blob file
        with open(blob_path, 'rb') as f:
            blob_content = f.read()
        file_size = len(blob_content)
        print(f"DEBUG: File size: {file_size}")

        # Check file header to determine format
        if len(blob_content) >= 8:
            header = blob_content[:8].hex()
            print(f"DEBUG: File header: {header}")
            # PNG header: 89 50 4E 47 0D 0A 1A 0A
            if header.startswith('89504e470d0a1a0a'):
                extension = '.png'
                mime_type = 'image/png'
            # JPEG header: FF D8 FF
            elif header.startswith('ffd8ff'):
                extension = '.jpg'
                mime_type = 'image/jpeg'
            # GIF header: 47 49 46 38
            elif header.startswith('47494638'):
                extension = '.gif'
                mime_type = 'image/gif'
            else:
                # Default to PNG if we can't determine the format
                extension = '.png'
                mime_type = 'image/png'
        else:
            extension = '.png'
            mime_type = 'image/png'
        print(f"DEBUG: Detected extension: {extension}, MIME type: {mime_type}")

        # Generate a unique filename from the content hash
        content_hash = hashlib.md5(blob_content).hexdigest()[:8]
        new_filename = f"flagged_image_{content_hash}{extension}"
        print(f"DEBUG: Generated filename: {new_filename}")

        # Create new path in the same directory as the blob
        temp_dir = os.path.dirname(blob_path)
        new_path = os.path.join(temp_dir, new_filename)
        print(f"DEBUG: New path: {new_path}")

        # Write the content to the new file
        with open(new_path, 'wb') as f:
            f.write(blob_content)
        print(f"DEBUG: Successfully renamed blob to: {new_path}")

        # Update the image data with the resolved file information
        converted_data = {
            'path': new_path,
            'url': (image_data.get('url') or '').replace('blob', new_filename),
            'size': file_size,
            'orig_name': new_filename,
            'mime_type': mime_type,
            'is_stream': False,
            'meta': image_data.get('meta', {}),
        }
        print(f"DEBUG: Converted data: {converted_data}")
        return converted_data
    else:
        print("DEBUG: Not a blob, skipping conversion")

    print(f"DEBUG: Converted image: {image_data}")
    return image_data
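# For reference, a blob payload that triggers the conversion above looks roughly
# like this (illustrative values only):
#   {'path': '/tmp/gradio/.../blob', 'url': 'http://.../file=.../blob',
#    'size': None, 'orig_name': None, 'mime_type': None, 'is_stream': False}
# and the converted payload mirrors it with a real filename, size and MIME type,
# e.g. 'orig_name': 'flagged_image_ab12cd34.png', 'mime_type': 'image/png'.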


def flag_img_input(
    image: gr.Image, flag_option: str = "misdetection", username: str = "anonymous"
):
    """Wrapper for flagging"""
    print(f"{image=}, {flag_option=}, {username=}")

    # Decode blob data if necessary
    decoded_image = decode_blob_data(image)
    hf_writer.flag([decoded_image], flag_option=flag_option, username=username)


# Flagging
dataset_name = "SEA-AI/crowdsourced-sea-images"
hf_writer = HuggingFaceDatasetSaver(get_token(), dataset_name)
flagged_counter = FlaggedCounter(dataset_name)
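# hf_writer appends each flagged sample to the dataset above, and
# flagged_counter reports how many samples it already holds (behaviour inferred
# from the helpers' names and their use below).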

theme = gr.themes.Default(primary_hue=gr.themes.colors.indigo)

with gr.Blocks(theme=theme, css=css, title="SEA.AI Vision Demo") as demo:
    badges = gr.HTML(load_badges(flagged_counter.count()))
    title = gr.HTML(TITLE)

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(
                label="input",
                interactive=True,
                sources=["upload", "clipboard"],
            )
            img_url = gr.Textbox(
                lines=1,
                placeholder="or enter URL to image here",
                label="input_url",
                show_label=False,
            )
            with gr.Row():
                clear = gr.ClearButton()
                submit = gr.Button("Submit", variant="primary")
        with gr.Column():
            img_output = gr.Image(label="output", interactive=False)
            flag = gr.Button(FLAG_TXT, visible=False)
            notice = gr.Markdown(value=NOTICE, visible=False)

    examples = gr.Examples(
        examples=glob.glob("examples/*.jpg"),
        inputs=img_input,
        outputs=img_output,
        fn=inference,
        cache_examples=True,
    )

    # add components to clear when clear button is clicked
    clear.add([img_input, img_url, img_output])

    # event listeners
    img_url.change(load_image_from_url, [img_url], img_input)
    submit.click(check_image, [img_input], None, show_api=False).success(
        inference,
        [img_input],
        img_output,
        api_name="inference",
    )
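    # check_image() acts as a gate: .success() only fires the chained inference
    # call if the validation step finished without raising an error.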

    # event listeners with decorators
    @img_output.change(
        inputs=[img_input, img_output],
        outputs=[flag, notice],
        preprocess=False,
        show_api=False,
    )
    def _show_hide_flagging(_img_input, _img_output):
        # Show the flagging controls only when there is an output image and the
        # input is not one of the bundled example images.
        visible = bool(_img_output) and _img_input["orig_name"] not in os.listdir("examples")
        return {
            flag: gr.Button(FLAG_TXT, interactive=True, visible=visible),
            notice: gr.Markdown(value=NOTICE, visible=visible),
        }

    # This needs to be called prior to the first call to callback.flag()
    hf_writer.setup([img_input], "flagged")

    # Sequential logic when flag button is clicked
    flag.click(lambda: gr.Info("Thank you for contributing!"), show_api=False).then(
        lambda: {flag: gr.Button(FLAG_TXT, interactive=False)},
        [],
        [flag],
        show_api=False,
    ).then(
        flag_img_input,
        [img_input],
        [],
        preprocess=False,
        show_api=True,
        api_name="flag_misdetection",
    ).then(
        lambda: load_badges(flagged_counter.count()),
        [],
        badges,
        show_api=False,
    )
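    # Because the flagging step runs with preprocess=False, flag_img_input()
    # receives the raw FileData dict rather than a decoded image, which is why
    # decode_blob_data() resolves blob uploads into real image files first.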

    # called during initial load in browser
    demo.load(lambda: load_badges(flagged_counter.count()), [], badges, show_api=False)


if __name__ == "__main__":
    demo.queue().launch()
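# The named endpoints can also be called programmatically, e.g. with
# gradio_client (sketch only; the Space id is a placeholder and the exact
# payload handling depends on the installed gradio_client version):
#
#   from gradio_client import Client, handle_file
#   client = Client("<owner>/<space-id>")
#   result = client.predict(handle_file("examples/some_image.jpg"),
#                           api_name="/inference")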