"""
Main application for the RGB object-detection demo.

Any new model should implement the following functions:
- load_model(model_path, img_size=640)
- inference(model, image)
"""
import os
import glob
import spaces
import gradio as gr
from huggingface_hub import get_token
from utils import (
    check_image,
    load_image_from_url,
    load_badges,
    FlaggedCounter,
)
from flagging import HuggingFaceDatasetSaver
import install_private_repos
from seavision import load_model, AHOY


# HTML header rendered at the top of the page (via gr.HTML).
TITLE = """
<h1> 🌊 SEA.AI's Machine Vision Demo ✨ </h1>
<p align="center">
Ahoy! Explore our object detection technology!
Upload a maritime scene image and click <code>Submit</code>
to see the results.
</p>
"""

# Label of the flag button; also interpolated into NOTICE below so the
# instructions always name the actual button text.
FLAG_TXT = "Report Mis-detection"

# Markdown shown next to the flag button whenever the button is visible.
NOTICE = f"""
🚩 See something off? Your feedback makes a difference! Let us know by
flagging any outcomes that don't seem right. Click the `{FLAG_TXT}` button
to submit the image for review.
"""

# Custom CSS passed to gr.Blocks: centers the <h1> title.
css = """
h1 {
    text-align: center;
    display: block;
}
"""

# Detection model, loaded once at import time and reused by every inference
# call below. NOTE(review): the argument looks like a model identifier rather
# than a filesystem path — confirm against seavision.load_model.
model = load_model("ahoy-RGB-b2")

@spaces.GPU
def inference(model: AHOY, image):
    """Detect objects in ``image`` and return an annotated copy.

    Decorated with ``spaces.GPU`` so that, on Hugging Face ZeroGPU Spaces,
    the call can be scheduled on a GPU.

    Args:
        model: Loaded detection model; it is called directly on the image.
        image: Input image as delivered by the Gradio ``Image`` component.

    Returns:
        Whatever the model's result object's ``draw`` method returns when
        rendering the detections onto ``image`` with a marker diameter of 4.
    """
    detections = model(image)
    annotated = detections.draw(image, diameter=4)
    return annotated

# Flagging: images reported by users via the flag button are handed to
# hf_writer, which targets this Hugging Face dataset.
dataset_name = "SEA-AI/crowdsourced-sea-images"
# Writer invoked on each flag click. get_token() returns the locally configured
# Hugging Face token (None if none is configured).
hf_writer = HuggingFaceDatasetSaver(get_token(), dataset_name)
# Source of the flag count rendered in the page badges.
flagged_counter = FlaggedCounter(dataset_name)


theme = gr.themes.Default(primary_hue=gr.themes.colors.indigo)
with gr.Blocks(theme=theme, css=css, title="SEA.AI Vision Demo") as demo:
    # Badges are rendered once at build time and refreshed on page load and
    # after each flag (see the listeners below).
    badges = gr.HTML(load_badges(flagged_counter.count()))
    title = gr.HTML(TITLE)

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(
                label="input", interactive=True, sources=["upload", "clipboard"]
            )
            img_url = gr.Textbox(
                lines=1,
                placeholder="or enter URL to image here",
                label="input_url",
                show_label=False,
            )
            with gr.Row():
                clear = gr.ClearButton()
                submit = gr.Button("Submit", variant="primary")
        with gr.Column():
            img_output = gr.Image(label="output", interactive=False)
            # Hidden until a prediction on a user-supplied image is shown.
            flag = gr.Button(FLAG_TXT, visible=False)
            notice = gr.Markdown(value=NOTICE, visible=False)

    # Sorted so the example order is stable across runs: glob's order is
    # filesystem-dependent, and cached example outputs are stored by position.
    example_paths = sorted(glob.glob("examples/*.jpg"))
    # File names of the bundled examples; predictions on these are not
    # offered for flagging. Computed once instead of on every output change.
    example_names = {os.path.basename(path) for path in example_paths}

    examples = gr.Examples(
        examples=example_paths,
        inputs=img_input,
        outputs=img_output,
        fn=lambda image: inference(model, image),
        cache_examples=True,
    )

    # add components to clear when clear button is clicked
    clear.add([img_input, img_url, img_output])

    # event listeners
    img_url.change(load_image_from_url, [img_url], img_input)
    # Validate the input first; inference only runs if check_image succeeds.
    submit.click(check_image, [img_input], show_api=False).success(
        lambda image: inference(model, image),
        [img_input],
        img_output,
        api_name="inference",
    )

    # event listeners with decorators
    @img_output.change(
        inputs=[img_input, img_output],
        outputs=[flag, notice],
        show_api=False,
        preprocess=False,
        show_progress="hidden",
    )
    def _show_hide_flagging(_img_input, _img_output):
        """Show the flag button and notice only for user-supplied images.

        With ``preprocess=False`` both arguments are raw component payloads:
        falsy when empty, otherwise ``_img_input`` is dict-like with an
        ``"orig_name"`` entry holding the uploaded file's name. Only the
        truthiness of ``_img_output`` is used.

        Returns:
            Updates for ``flag`` and ``notice`` with a boolean ``visible``.
        """
        if _img_input and _img_output:
            visible = _img_input.get("orig_name") not in example_names
        else:
            # Nothing to flag (e.g. after Clear). Also avoids subscripting a
            # None input, and keeps ``visible`` a real bool.
            visible = False
        return {
            flag: gr.Button(FLAG_TXT, interactive=True, visible=visible),
            notice: gr.Markdown(value=NOTICE, visible=visible),
        }

    # This needs to be called prior to the first call to hf_writer.flag()
    hf_writer.setup([img_input], "flagged")

    # Sequential logic when flag button is clicked:
    # thank the user -> disable the button (prevents double flagging) ->
    # hand the raw payloads to the writer -> refresh the badge count.
    flag.click(lambda: gr.Info("Thank you for contributing!"), show_api=False).then(
        lambda: {flag: gr.Button(FLAG_TXT, interactive=False)},
        [],
        [flag],
        show_api=False,
    ).then(
        # Receives the raw (preprocess=False) values of img_input and flag.
        lambda *args: hf_writer.flag(args),
        [img_input, flag],
        [],
        preprocess=False,
        show_api=False,
    ).then(
        lambda: load_badges(flagged_counter.count()), [], badges, show_api=False
    )

    # called during initial load in browser
    demo.load(lambda: load_badges(flagged_counter.count()), [], badges, show_api=False)

if __name__ == "__main__":
    # Enable Gradio's request queue on the app, then start the web server.
    demo.queue().launch()