import gradio as gr

from plotting import create_yolobench_plots, get_pareto_table, create_comparison_plot
from utils import DEEPLITE_DARK_BLUE_GRADIO


# Datasheet / product-brief URL for every hardware platform offered in the UI.
# Defined once at import time instead of being rebuilt on every dropdown change.
HW_URLS = {
    'Jetson Nano (GPU, ONNX Runtime, FP32)': 'https://8074457.fs1.hubspotusercontent-na1.net/hubfs/8074457/YOLOBench%20Hardware%20product%20sheets/JetsonNano_DataSheet_DS09366001v1.1.pdf',
    'Raspberry Pi 4 Model B (CPU, TFLite, FP32)': 'https://8074457.fs1.hubspotusercontent-na1.net/hubfs/8074457/YOLOBench%20Hardware%20product%20sheets/raspberry-pi-4-datasheet.pdf',
    'Raspberry Pi 4 Model B (CPU, ONNX Runtime, FP32)': 'https://8074457.fs1.hubspotusercontent-na1.net/hubfs/8074457/YOLOBench%20Hardware%20product%20sheets/raspberry-pi-4-datasheet.pdf',
    'Raspberry Pi 5 Model B (CPU, ONNX Runtime, FP32)': 'https://8074457.fs1.hubspotusercontent-na1.net/hubfs/8074457/YOLOBench%20Assets/Hardware%20Product%20Assets/raspberry-pi-5-product-brief.pdf',
    'Intel® Core™i7-10875H (CPU, OpenVINO, FP32)': 'https://8074457.fs1.hubspotusercontent-na1.net/hubfs/8074457/YOLOBench%20Hardware%20product%20sheets/Intel_ARK_SpecificationsChart_2023_10_11.pdf',
    'Khadas VIM3 (NPU, INT16)': 'https://8074457.fs1.hubspotusercontent-na1.net/hubfs/8074457/YOLOBench%20Hardware%20product%20sheets/khadas_vim3_specs.pdf',
    'Orange Pi 5 (NPU, FP16)': 'https://8074457.fs1.hubspotusercontent-na1.net/hubfs/8074457/YOLOBench%20Hardware%20product%20sheets/OrangePi_5_RK3588S_User%20Manual_v1.5.pdf',
    'NVIDIA A40 (GPU, TensorRT, FP32)': 'https://images.nvidia.com/content/Solutions/data-center/a40/nvidia-a40-datasheet.pdf',
}


def get_hw_description(hw_name):
    """Return the Markdown/HTML blurb describing the selected hardware platform.

    The blurb contains up to two bullet lines:
    a link to the platform's datasheet (looked up in ``HW_URLS``) and a pointer
    to the Deeplite Torch Zoo latency-measurement results.

    Layout: a leading newline, then each bullet on its own line indented by four
    spaces, then a trailing four-space indent.

    Args:
        hw_name: Hardware display name, e.g.
            ``'Jetson Nano (GPU, ONNX Runtime, FP32)'``.

    Returns:
        The description string. If ``hw_name`` has no entry in ``HW_URLS``
        (including ``None``), the datasheet bullet is omitted instead of raising
        ``KeyError``. This way a UI callback never crashes on an unlisted
        platform.
    """
    bullets = []

    hw_url = HW_URLS.get(hw_name)
    if hw_url is not None:
        bullets.append(
            '🔸 <span style="font-size:16px">Click </span>'
            f'[<span style="font-size:16px">here</span>]({hw_url})'
            '<span style="font-size:16px"> for more information on the selected hardware platform.</span>'
        )

    # Shown regardless of the selected platform.
    bullets.append(
        '🔸 <span style="font-size:16px">Refer to the '
        '[Deeplite Torch Zoo](https://github.com/Deeplite/deeplite-torch-zoo/tree/develop/results/yolobench)'
        ' for details about latency measurement experiments.</span>'
    )

    return "\n" + "".join(f"    {bullet}\n" for bullet in bullets) + "    "


# Choice lists shared by both tabs, defined once so the tabs cannot drift apart.
# The selected hardware name from the model-comparison tab is also passed verbatim
# to get_hw_description(), so these names must match the ones that function knows.
HARDWARE_CHOICES = (
    'Jetson Nano (GPU, ONNX Runtime, FP32)',
    'Raspberry Pi 4 Model B (CPU, TFLite, FP32)',
    'Raspberry Pi 4 Model B (CPU, ONNX Runtime, FP32)',
    'Raspberry Pi 5 Model B (CPU, ONNX Runtime, FP32)',
    'Intel® Core™i7-10875H (CPU, OpenVINO, FP32)',
    'Khadas VIM3 (NPU, INT16)',
    'Orange Pi 5 (NPU, FP16)',
    'NVIDIA A40 (GPU, TensorRT, FP32)',
)
DATASET_CHOICES = ('COCO', 'PASCAL VOC', 'SKU-110K', 'WIDERFACE')
METRIC_CHOICES = ('mAP@0.5:0.95', 'mAP@0.5', 'Precision', 'Recall')


with gr.Blocks(
    theme=gr.themes.Default(secondary_hue=DEEPLITE_DARK_BLUE_GRADIO),
    css="table { width: 100%; }",
    analytics_enabled=True,
) as demo:
    gr.HTML(
        """
        <div align="center">
        <img src="file/banner.png"/>
        </div>
        """
    )

    # Stay on the light theme unless the URL requests the dark theme ("theme=dark"),
    # or requests the system theme ("theme=system") while the OS prefers dark mode.
    demo.load(
        None,
        _js="""
        () => {
            let mediaQueryObj = window.matchMedia('(prefers-color-scheme: dark)');
            let systemDarkTheme = window.location.href.includes("theme=system") && mediaQueryObj.matches;
            if (window.location.href.includes("theme=dark") || systemDarkTheme){
                document.body.classList.toggle('dark');
                document.querySelector('gradio-app').style.backgroundColor = 'var(--color-background-primary)'
            }
        }
        """,
    )

    # Inject the Google Analytics (gtag.js) snippet on page load.
    # NOTE(review): 'debug_mode' is set to true below; confirm that is intended for
    # the public deployment.
    demo.load(
        None,
        _js="""
        () => {
            const script2 = document.createElement("script");
            script2.src = "https://www.googletagmanager.com/gtag/js?id=G-01G83VTHE0";
            script2.async = true;
            document.head.appendChild(script2);
            window.dataLayer = window.dataLayer || [];
            function gtag(){dataLayer.push(arguments);}
            gtag('js', new Date());
            gtag('config', 'G-01G83VTHE0', {
                'page_path': "/spaces/deepliteai/yolobench",
                'page_title': 'yolobench',
                'cookie_flags': 'SameSite=None;Secure',
                'debug_mode':true,
            });
        }
        """,
    )

    with gr.Row():
        gr.Markdown(
            """
        <span style="font-size:16px">

        🚀 <b>YOLOBench</b> 🚀 is a latency-accuracy benchmark of popular single-stage detectors from the YOLO series. Major highlights of this work are:

        🔸 includes architectures from YOLOv3 to YOLOv8, <br>
        🔸 trained on <span style="font-weight:bold">four</span> popular object detection datasets (COCO, VOC, WIDER FACE, SKU-110k), <br>
        🔸 latency measured on <span style="font-weight:bold">a growing list of hardware platforms</span> (examples include Jetson Nano GPU, ARM CPU, Intel CPU, Khadas VIM3 NPU, Orange Pi NPU), <br>
        🔸 all models are trained with <span style="font-weight:bold">the same</span> training loop and hyperparameters (as implemented in the [Ultralytics YOLOv8 codebase](https://github.com/ultralytics/ultralytics)), <br>
        🔸 both <span style="font-weight:bold">the detection head structure</span> and <span style="font-weight:bold"> the loss function </span> used are those of YOLOv8, giving a chance to isolate the contribution of the backbone/neck architecture on the latency-accuracy trade-off of YOLO models. <br>
        In particular, we show that older backbone/neck structures like those of YOLOv3 and YOLOv4 are still competitive compared to more recent architectures in a controlled environment. For more details, please refer to the [arXiv preprint](https://arxiv.org/abs/2307.13901) and the [codebase](https://github.com/Deeplite/deeplite-torch-zoo).

        #

        </span>

        #
        """
        )

    with gr.Tab("YOLO model comparison"):
        with gr.Row(equal_height=True):
            with gr.Column():
                hardware_name = gr.Dropdown(
                    choices=list(HARDWARE_CHOICES),
                    value='Jetson Nano (GPU, ONNX Runtime, FP32)',
                    label='Hardware target',
                )
            with gr.Column():
                dataset_name = gr.Dropdown(
                    choices=list(DATASET_CHOICES),
                    value='COCO',
                    label='Dataset',
                )

        with gr.Row(equal_height=True):
            with gr.Column():
                # Initial blurb for the default hardware; refreshed by the
                # hardware_name.change handler registered at the end of this block.
                hardware_desc = gr.Markdown(get_hw_description(hardware_name.value))

            with gr.Column():
                metric_name = gr.Radio(
                    choices=list(METRIC_CHOICES),
                    value='mAP@0.5:0.95',
                    label='Accuracy metric to plot',
                )

        with gr.Row(equal_height=True):
            with gr.Column():
                gr.Markdown(
                    """
                            <span style="font-size:16px">

                            🚀 <span style="font-weight:bold">Want to add your own hardware benchmarks to YOLOBench?</span> 🚀
                            Contact us [here](https://info.deeplite.ai/add_yolobench_data) for your benchmarking kit and we'll set you up!

                            </span>
                            """
                )

            with gr.Column():
                vis_options = gr.CheckboxGroup(
                    choices=[
                        'Model family',
                        'Highlight Pareto',
                        'Show Pareto only',
                        'Log x-axis',
                    ],
                    value=['Model family'],
                    label='Visualization options',
                )

        with gr.Row(equal_height=True):
            upper_panel_fig = gr.Plot(show_label=False)

    with gr.Tab("Hardware platform comparison"):
        with gr.Row(equal_height=True):
            with gr.Column():
                comp_hw = gr.Dropdown(
                    choices=list(HARDWARE_CHOICES),
                    value=[
                        'Jetson Nano (GPU, ONNX Runtime, FP32)',
                        'Intel® Core™i7-10875H (CPU, OpenVINO, FP32)',
                    ],
                    label='Hardware',
                    multiselect=True,
                )
            with gr.Column():
                comp_data = gr.Dropdown(
                    choices=list(DATASET_CHOICES),
                    value='COCO',
                    label='Dataset',
                )

        with gr.Row(equal_height=True):
            with gr.Column():
                comp_metric = gr.Radio(
                    choices=list(METRIC_CHOICES),
                    value='mAP@0.5:0.95',
                    label='Accuracy metric to plot',
                )

            with gr.Column():
                comp_vis_opt = gr.CheckboxGroup(
                    choices=['Log x-axis', 'Remove datapoint markers'],
                    value=['Log x-axis'],
                    label='Visualization options',
                )

        with gr.Row(equal_height=True):
            comp_plot = gr.Plot(show_label=False)

    gr.Markdown(
        """
        ##
        <span style="font-size:16px">

        Models from this benchmark can be loaded using [Deeplite Torch Zoo](https://github.com/Deeplite/deeplite-torch-zoo) as follows:

        </span>

        ##

        ```python
        from deeplite_torch_zoo import get_model
        model = get_model(
            model_name='yolo4n',        # create a YOLOv4n model for the COCO dataset
            dataset_name='coco',        # (`n` corresponds to width factor 0.25, depth factor 0.33)
            pretrained=False,           #
            custom_head='v8'            # attach a YOLOv8 detection head to YOLOv4n backbone+neck
        )
        ```

        <span style="font-size:16px">

        To train a model, run

        </span>

        ```python
        from deeplite_torch_zoo.trainer import Detector
        model = Detector(torch_model=model)                            # previously created YOLOv4n model
        model.train(data='VOC.yaml', epochs=100, imgsz=480)            # same arguments as the Ultralytics trainer object
        ```

        ##

        <details>
        <summary>Model naming conventions</summary>

        ##

        The model naming convention is that a model named `yolo8d67w25` is a YOLOv8 model with a depth factor of 0.67 and width factor of 0.25. Conventional depth/width factor value namings (n, s, m, l models) are used where possible. YOLOv6(s, m, l) models are considered to be different architectures due to differences other than the depth/width factor value. For every architecture, there are 3 variations in depth factor (0.33, 0.67, 1.0) and 4 variations in width factor (0.25, 0.5, 0.75, 1.0), except for YOLOv7 models, for which only width factor variations are considered while depth is fixed.
        </details>

        ##

        <span style="font-size:20px">
        Pareto-optimal models
        </span>

        ##

        COCO pre-trained models are ready for download.  Other models coming soon!
        """
    )

    table_mode = gr.Radio(
        choices=['Show top-10 models', 'Show all'],
        value='Show top-10 models',
        label='Pareto model table',
    )

    with gr.Row():
        # Filled with HTML by the create_yolobench_plots / get_pareto_table handlers below.
        pareto_table = gr.HTML()

    gr.Markdown(
        """
        ## Citation
        ```
        Accepted at ICCV 2023 Workshop on Resource-Efficient Deep Learning for Computer Vision (RCV'23)
        @article{lazarevich2023yolobench,
            title={YOLOBench: Benchmarking Efficient Object Detectors on Embedded Systems},
            author={Lazarevich, Ivan and Grimaldi, Matteo and Kumar, Ravish and Mitra, Saptarshi and Khan, Shahrukh and Sah, Sudhakar},
            journal={arXiv preprint arXiv:2307.13901},
            year={2023}
        }
        ```
        """
    )

    inputs = [dataset_name, hardware_name, metric_name, vis_options, table_mode]
    inputs_comparison = [comp_data, comp_hw, comp_metric, comp_vis_opt]

    # Render the initial plots/table for the default selections
    # (COCO, Jetson Nano, mAP@0.5:0.95).
    demo.load(
        fn=create_yolobench_plots,
        inputs=inputs,
        outputs=[upper_panel_fig, pareto_table],
    )

    demo.load(
        fn=create_comparison_plot,
        inputs=inputs_comparison,
        outputs=[comp_plot],
    )

    # NOTE(review): the create_yolobench_plots load above already writes
    # pareto_table, and this handler does not receive table_mode. Confirm it is
    # still needed and that its output agrees with the selected table mode.
    demo.load(
        fn=get_pareto_table,
        inputs=[dataset_name, hardware_name, metric_name],
        outputs=[pareto_table],
    )

    # Changing any input of create_yolobench_plots re-renders the scatter plot
    # and the Pareto table.
    for control in (dataset_name, metric_name, vis_options, table_mode, hardware_name):
        control.change(
            fn=create_yolobench_plots,
            inputs=inputs,
            outputs=[upper_panel_fig, pareto_table],
        )

    # Selecting a different device also refreshes the hardware description blurb.
    hardware_name.change(
        fn=get_hw_description,
        inputs=[hardware_name],
        outputs=[hardware_desc],
    )

    # Every control on the hardware-comparison tab re-renders the comparison plot.
    for control in (comp_data, comp_hw, comp_metric, comp_vis_opt):
        control.change(
            fn=create_comparison_plot,
            inputs=inputs_comparison,
            outputs=[comp_plot],
        )

if __name__ == "__main__":
    demo.launch()