""" Main application for RGB detection demo. Any new model should implement the following functions: - load_model(model_path, img_size=640) - inference(model, image) """ import os import glob import hashlib import struct # import spaces import gradio as gr from huggingface_hub import get_token from utils import ( check_image, load_image_from_url, load_badges, FlaggedCounter, ) from flagging import HuggingFaceDatasetSaver import install_private_repos # noqa: F401 from seavision import load_model TITLE = """

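
# --- Illustrative sketch (not used by this app) ------------------------------
# A rough outline of the model interface described in the module docstring,
# based on how `model` is used below: `load_model` returns an object that is
# callable on an image, carries `det_conf_thresh`/`hor_conf_thresh` attributes,
# and yields results exposing `draw(image)`. The `_Example*` names are
# hypothetical placeholders; the real implementation lives in `seavision`.
class _ExampleResults:
    """Hypothetical results object mirroring what `model(image)` returns."""

    def draw(self, image):
        # A real results object would draw detections; this sketch is a no-op.
        return image


class _ExampleModel:
    """Hypothetical model wrapper implementing the documented interface."""

    def __init__(self, model_path, img_size=640):
        self.model_path = model_path
        self.img_size = img_size
        self.det_conf_thresh = 0.25  # detection confidence threshold
        self.hor_conf_thresh = 0.25  # horizon confidence threshold

    def __call__(self, image):
        return _ExampleResults()


def _example_load_model(model_path, img_size=640):
    """Mirrors the `load_model(model_path, img_size=640)` contract above."""
    return _ExampleModel(model_path, img_size=img_size)
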

TITLE = """
<h1>🌊 SEA.AI's Vision Demo ✨</h1>
Ahoy! Explore our object detection technology! Upload a maritime scene image
and click Submit to see the results.
"""
""" FLAG_TXT = "Report Mis-detection" NOTICE = f""" 🚩 See something off? Your feedback makes a difference! Let us know by flagging any outcomes that don't seem right. Click the `{FLAG_TXT}` button to submit the image for review. """ css = """ h1 { text-align: center; display: block; } """ model = load_model("experimental/ahoy6-MIX-1280-b1.onnx") model.det_conf_thresh = 0.1 model.hor_conf_thresh = 0.1 # @spaces.GPU def inference(image): """Run inference on image and return annotated image.""" results = model(image) return results.draw(image) def decode_blob_data(image_data): """ Decode blob data from Gradio image component. Handles blob format and converts to proper image file format. """ if not isinstance(image_data, dict): return image_data print(f"DEBUG: Original input - image: {image_data}") # Check if this is blob data - more comprehensive check is_blob = ( 'path' in image_data and 'blob' in image_data['path'] and image_data.get('size') is None and image_data.get('orig_name') is None and image_data.get('mime_type') is None ) if is_blob: print(f"DEBUG: Converting blob data: {image_data}") print("DEBUG: Detected blob format, converting...") blob_path = image_data['path'] print(f"DEBUG: Blob path: {blob_path}") # Read the blob file with open(blob_path, 'rb') as f: blob_content = f.read() file_size = len(blob_content) print(f"DEBUG: File size: {file_size}") # Check file header to determine format if len(blob_content) >= 8: header = blob_content[:8].hex() print(f"DEBUG: File header: {header}") # PNG header: 89 50 4E 47 0D 0A 1A 0A if header.startswith('89504e470d0a1a0a'): extension = '.png' mime_type = 'image/png' # JPEG header: FF D8 FF elif header.startswith('ffd8ff'): extension = '.jpg' mime_type = 'image/jpeg' # GIF header: 47 49 46 38 elif header.startswith('47494638'): extension = '.gif' mime_type = 'image/gif' else: # Default to PNG if we can't determine extension = '.png' mime_type = 'image/png' else: extension = '.png' mime_type = 'image/png' print(f"DEBUG: Detected extension: {extension}, MIME type: {mime_type}") # Generate a unique filename content_hash = hashlib.md5(blob_content).hexdigest()[:8] new_filename = f"flagged_image_{content_hash}{extension}" print(f"DEBUG: Generated filename: {new_filename}") # Create new path in the same directory import tempfile temp_dir = os.path.dirname(blob_path) new_path = os.path.join(temp_dir, new_filename) print(f"DEBUG: New path: {new_path}") # Write the content to the new file with open(new_path, 'wb') as f: f.write(blob_content) print(f"DEBUG: Successfully renamed blob to: {new_path}") # Update the image data converted_data = { 'path': new_path, 'url': image_data['url'].replace('blob', new_filename), 'size': file_size, 'orig_name': new_filename, 'mime_type': mime_type, 'is_stream': False, 'meta': image_data.get('meta', {}) } print(f"DEBUG: Converted data: {converted_data}") return converted_data else: print("DEBUG: Not a blob, skipping conversion") print(f"DEBUG: Converted image: {image_data}") return image_data def flag_img_input( image: gr.Image, flag_option: str = "misdetection", username: str = "anonymous" ): """Wrapper for flagging""" print(f"{image=}, {flag_option=}, {username=}") # Decode blob data if necessary decoded_image = decode_blob_data(image) hf_writer.flag([decoded_image], flag_option=flag_option, username=username) # Flagging dataset_name = "SEA-AI/crowdsourced-sea-images" hf_writer = HuggingFaceDatasetSaver(get_token(), dataset_name) flagged_counter = FlaggedCounter(dataset_name) theme = 


def flag_img_input(
    image: gr.Image,
    flag_option: str = "misdetection",
    username: str = "anonymous",
):
    """Wrapper for flagging"""
    print(f"{image=}, {flag_option=}, {username=}")

    # Decode blob data if necessary
    decoded_image = decode_blob_data(image)
    hf_writer.flag([decoded_image], flag_option=flag_option, username=username)


# Flagging
dataset_name = "SEA-AI/crowdsourced-sea-images"
hf_writer = HuggingFaceDatasetSaver(get_token(), dataset_name)
flagged_counter = FlaggedCounter(dataset_name)

theme = gr.themes.Default(primary_hue=gr.themes.colors.indigo)

with gr.Blocks(theme=theme, css=css, title="SEA.AI Vision Demo") as demo:
    badges = gr.HTML(load_badges(flagged_counter.count()))
    title = gr.HTML(TITLE)

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(
                label="input",
                interactive=True,
                sources=["upload", "clipboard"],
            )
            img_url = gr.Textbox(
                lines=1,
                placeholder="or enter URL to image here",
                label="input_url",
                show_label=False,
            )
            with gr.Row():
                clear = gr.ClearButton()
                submit = gr.Button("Submit", variant="primary")
        with gr.Column():
            img_output = gr.Image(label="output", interactive=False)
            flag = gr.Button(FLAG_TXT, visible=False)
            notice = gr.Markdown(value=NOTICE, visible=False)

    examples = gr.Examples(
        examples=glob.glob("examples/*.jpg"),
        inputs=img_input,
        outputs=img_output,
        fn=inference,
        cache_examples=True,
    )

    # add components to clear when clear button is clicked
    clear.add([img_input, img_url, img_output])

    # event listeners
    img_url.change(load_image_from_url, [img_url], img_input)
    submit.click(check_image, [img_input], None, show_api=False).success(
        inference,
        [img_input],
        img_output,
        api_name="inference",
    )

    # event listeners with decorators
    @img_output.change(
        inputs=[img_input, img_output],
        outputs=[flag, notice],
        show_api=False,
        preprocess=False,
        show_progress="hidden",
    )
    def _show_hide_flagging(_img_input, _img_output):
        visible = _img_output and _img_input["orig_name"] not in os.listdir("examples")
        return {
            flag: gr.Button(FLAG_TXT, interactive=True, visible=visible),
            notice: gr.Markdown(value=NOTICE, visible=visible),
        }

    # This needs to be called prior to the first call to hf_writer.flag()
    hf_writer.setup([img_input], "flagged")

    # Sequential logic when flag button is clicked
    flag.click(lambda: gr.Info("Thank you for contributing!"), show_api=False).then(
        lambda: {flag: gr.Button(FLAG_TXT, interactive=False)},
        [],
        [flag],
        show_api=False,
    ).then(
        flag_img_input,
        [img_input],
        [],
        preprocess=False,
        show_api=True,
        api_name="flag_misdetection",
    ).then(
        lambda: load_badges(flagged_counter.count()),
        [],
        badges,
        show_api=False,
    )

    # called during initial load in browser
    demo.load(lambda: load_badges(flagged_counter.count()), [], badges, show_api=False)


if __name__ == "__main__":
    demo.queue().launch()
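

# --- Illustrative sketch (never invoked here) ---------------------------------
# Rough example of calling the two named API endpoints exposed above
# ("/inference" and "/flag_misdetection") from another machine with
# `gradio_client` (a recent version providing `handle_file` is assumed). The
# Space id and image path are placeholders to adapt to your deployment.
def _example_api_usage(
    space_id="<owner>/<space-name>", image_path="path/to/image.jpg"
):
    from gradio_client import Client, handle_file

    client = Client(space_id)
    # Run detection on a local image and receive the annotated image back.
    annotated = client.predict(handle_file(image_path), api_name="/inference")
    # Flag the same image for review (same input as the UI's flag button).
    client.predict(handle_file(image_path), api_name="/flag_misdetection")
    return annotated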