import gradio as gr
import sahi.utils.cv
from sahi import AutoDetectionModel
import sahi.predict
import sahi.slicing
from PIL import Image
import numpy
from huggingface_hub import hf_hub_download
import torch


IMAGE_SIZE = 640

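# Download the fine-tuned YOLOv5 weights from the Hugging Face Hub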
model_path = hf_hub_download("kadirnar/deprem_model_v1", filename="last.pt", revision="main")


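# Use the GPU when available, otherwise fall back to CPU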
current_device = "cuda" if torch.cuda.is_available() else "cpu"
model_types = ["YOLOv5", "YOLOv5 + SAHI"]
# Load the YOLOv5 detection model through SAHI's AutoDetectionModel wrapper
model = AutoDetectionModel.from_pretrained(
    model_type="yolov5", model_path=model_path, device=current_device, confidence_threshold=0.5, image_size=IMAGE_SIZE
)


def sahi_yolo_inference(
    model_type,
    image,
    slice_height=1280,
    slice_width=1280,
    overlap_height_ratio=0.1,
    overlap_width_ratio=0.1,
    postprocess_type="NMS",
    postprocess_match_metric="IOU",
    postprocess_match_threshold=0.5,
    postprocess_class_agnostic=False,
):
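    """Run detection on `image` with either plain YOLOv5 or YOLOv5 + SAHI
    sliced inference, and return the image with predictions drawn on it.

    The slicing and postprocess parameters are only used when a SAHI model
    type is selected.
    """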

    #image_width, image_height = image.size
    # sliced_bboxes = sahi.slicing.get_slice_bboxes(
    #     image_height,
    #     image_width,
    #     slice_height,
    #     slice_width,
    #     False,
    #     overlap_height_ratio,
    #     overlap_width_ratio,
    # )
    # if len(sliced_bboxes) > 60:
    #     raise ValueError(
    #         f"{len(sliced_bboxes)} slices are too much for huggingface spaces, try smaller slice size."
    #     )


    # Scale bounding-box and label-text thickness with the image size
    rect_th = max(round(sum(image.size) / 2 * 0.001), 1)
    text_th = max(rect_th - 1, 1)

    if "SAHI" in model_type:
        prediction_result_2 = sahi.predict.get_sliced_prediction(
        image=image,
        detection_model=model,
        slice_height=int(slice_height),
        slice_width=int(slice_width),
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio,
        postprocess_type=postprocess_type,
        postprocess_match_metric=postprocess_match_metric,
        postprocess_match_threshold=postprocess_match_threshold,
        postprocess_class_agnostic=postprocess_class_agnostic,
        )
        visual_result_2 = sahi.utils.cv.visualize_object_predictions(
            image=numpy.array(image),
            object_prediction_list=prediction_result_2.object_prediction_list,
            rect_th=rect_th,
            text_th=text_th,
        )
        output = Image.fromarray(visual_result_2["image"])
        return output

    else:
        # Standard (single-pass) YOLOv5 inference on the full image
        prediction_result_1 = sahi.predict.get_prediction(
            image=image, detection_model=model
        )
        visual_result_1 = sahi.utils.cv.visualize_object_predictions(
            image=numpy.array(image),
            object_prediction_list=prediction_result_1.object_prediction_list,
            rect_th=rect_th,
            text_th=text_th,
        )
        output = Image.fromarray(visual_result_1["image"])
        return output


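# Gradio UI: the inputs below map one-to-one to the parameters of
# sahi_yolo_inference, in the same order.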
inputs = [
    gr.inputs.Dropdown(choices=model_types, type="value", label="Choose Model Type"),
    gr.inputs.Image(type="pil", label="Original Image"),
    gr.inputs.Number(default=1920, label="slice_height"),
    gr.inputs.Number(default=1920, label="slice_width"),
    gr.inputs.Number(default=0.1, label="overlap_height_ratio"),
    gr.inputs.Number(default=0.1, label="overlap_width_ratio"),
    gr.inputs.Dropdown(
        ["NMS", "GREEDYNMM"],
        type="value",
        default="NMS",
        label="postprocess_type",
    ),
    gr.inputs.Dropdown(
        ["IOU", "IOS"], type="value", default="IOU", label="postprocess_match_metric"
    ),
    gr.inputs.Number(default=0.5, label="postprocess_match_threshold"),
    gr.inputs.Checkbox(default=True, label="postprocess_class_agnostic"),
]

outputs = [
    gr.outputs.Image(type="pil", label="Output")
]

title = "Small Object Detection with SAHI + YOLOv5"
description = "SAHI + YOLOv5 demo for small object detection. Upload an image or click one of the example images to try it."
article = "<p style='text-align: center'>SAHI is a lightweight vision library for performing large-scale object detection / instance segmentation. <a href='https://github.com/obss/sahi'>SAHI Github</a> | <a href='https://medium.com/codable/sahi-a-vision-library-for-performing-sliced-inference-on-large-images-small-objects-c8b086af3b80'>SAHI Blog</a> | <a href='https://github.com/fcakyon/yolov5-pip'>YOLOv5 Github</a> </p>"
examples = [
    [model_types[0], "26.jpg", 256, 256, 0.2, 0.2, "NMS", "IOU", 0.5, True],
    [model_types[0], "27.jpg", 512, 512, 0.2, 0.2, "NMS", "IOU", 0.5, True],
    [model_types[0], "28.jpg", 512, 512, 0.2, 0.2, "NMS", "IOU", 0.5, True],
    [model_types[0], "31.jpg", 512, 512, 0.2, 0.2, "NMS", "IOU", 0.5, True],
]
gr.Interface(
    sahi_yolo_inference,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    examples=examples,
    theme="huggingface",
).launch(debug=True, enable_queue=True)